├── .circleci
│   └── config.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE.txt
├── MANIFEST.in
├── ORIGINAL_README.rst
├── PULL_REQUEST_TEMPLATE.md
├── README.rst
├── config
│   ├── database.cfg
│   └── test
│       ├── aws.cfg
│       └── database.cfg
├── docs
│   ├── Makefile
│   ├── README
│   ├── api
│   │   ├── ansi.rst
│   │   └── env.rst
│   ├── api_reference.rst
│   ├── cli.rst
│   ├── conf.py
│   ├── config.rst
│   ├── images
│   │   └── model.png
│   ├── index.rst
│   ├── io.rst
│   ├── models.rst
│   ├── utils.rst
│   └── www.rst
├── lore
│   ├── __init__.py
│   ├── __main__.py
│   ├── ansi.py
│   ├── callbacks.py
│   ├── data
│   │   └── names.csv
│   ├── dependencies.py
│   ├── encoders.py
│   ├── env.py
│   ├── estimators
│   │   ├── __init__.py
│   │   ├── holt_winters
│   │   │   ├── __init__.py
│   │   │   └── holtwinters.py
│   │   ├── keras.py
│   │   ├── naive.py
│   │   ├── sklearn.py
│   │   └── xgboost.py
│   ├── features
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── db.py
│   │   └── s3.py
│   ├── io
│   │   ├── __init__.py
│   │   ├── connection.py
│   │   └── multi_connection_proxy.py
│   ├── metadata
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── keras.py
│   │   ├── naive.py
│   │   ├── sklearn.py
│   │   └── xgboost.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── holdout.py
│   │   ├── iterative.py
│   │   └── time_series.py
│   ├── stores
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── disk.py
│   │   ├── ram.py
│   │   ├── redis.py
│   │   └── s3.py
│   ├── tasks
│   │   ├── __init__.py
│   │   └── base.py
│   ├── template
│   │   ├── architecture.py.j2
│   │   ├── estimator.py.j2
│   │   ├── features.py.j2
│   │   ├── init
│   │   │   ├── .env.template
│   │   │   ├── .gitignore
│   │   │   ├── .keras
│   │   │   │   └── keras.json
│   │   │   ├── Procfile
│   │   │   ├── README.rst
│   │   │   ├── app
│   │   │   │   ├── __init__.py
│   │   │   │   ├── estimators
│   │   │   │   │   └── __init__.py
│   │   │   │   ├── extracts
│   │   │   │   │   └── .gitkeep
│   │   │   │   ├── models
│   │   │   │   │   └── __init__.py
│   │   │   │   └── pipelines
│   │   │   │       └── __init__.py
│   │   │   ├── config
│   │   │   │   ├── aws.cfg
│   │   │   │   └── database.cfg
│   │   │   ├── notebooks
│   │   │   │   └── .gitkeep
│   │   │   └── tests
│   │   │       ├── __init__.py
│   │   │       └── unit
│   │   │           └── __init__.py
│   │   ├── model.py.j2
│   │   ├── pipeline.py.j2
│   │   └── test.py.j2
│   ├── transformers.py
│   ├── util.py
│   └── www
│       └── __init__.py
├── notebooks
│   └── names.ipynb
├── pylintrc
├── release.sh
├── requirements.txt
├── runtime.txt
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    ├── lore_test.py
    ├── mocks
    │   ├── __init__.py
    │   ├── features.py
    │   ├── models_keras.py
    │   ├── models_other.py
    │   ├── pipelines.py
    │   └── tasks.py
    └── unit
        ├── __init__.py
        ├── io
        │   ├── __init__.py
        │   ├── test_connection.py
        │   └── test_io.py
        ├── test_encoders.py
        ├── test_env.py
        ├── test_estimators.py
        ├── test_features.py
        ├── test_main.py
        ├── test_metadata.py
        ├── test_models_keras.py
        ├── test_models_other.py
        ├── test_pipelines.py
        ├── test_stores.py
        └── test_transformers.py

/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.0
2 | 
3 | notify:
4 |   webhooks:
5 |     - url: https://www.instacart.com/circleci/notify
6 |       branches:
7 |         only:
8 |           - /.*/
9 | 
10 | workflows:
11 |   version: 2
12 |   build:
13 |     jobs:
14 |       - "python-3.7"
15 |       - "python-3.6.5"
16 | 
17 | shared: &shared
18 |   working_directory: ~/instacart/lore
19 |   shell: /bin/bash --login
20 | 
21 |   environment:
22 |     PYENV_ROOT: /home/circleci/.pyenv
23 |     CIRCLE_ARTIFACTS: /tmp/circleci-artifacts
24 |     CIRCLE_TEST_REPORTS: /tmp/circleci-test-results
25 | 
26 |   filters:
27 |     branches:
28 |       ignore:
29 |         - /.*\/ci_skip.*/ # end your branch name in /ci_skip and it won't be built by CircleCI
30 | 
31 | jobs:
32 |   "python-3.7":
33 |     <<: *shared
34 |     docker:
35 |       - image: circleci/python:3.7.6-stretch
36 |         environment:
37 |           - LORE_PYTHON_VERSION=3.7.6
38 |       - image: circleci/postgres:10-ram
39 |         environment:
40 |           - POSTGRES_USER=circleci
41 |           - POSTGRES_HOST_AUTH_METHOD=trust
42 |           - POSTGRES_DB=lore_test
43 |     steps:
44 |       - run: sudo chown -R circleci:circleci /usr/local/bin
45 |       - run: sudo chown -R circleci:circleci /usr/local/lib/python*
46 | 
47 |       - run: mkdir -p $CIRCLE_ARTIFACTS $CIRCLE_TEST_REPORTS
48 | 
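      # Cache keys embed a version segment (source-v2, pyenv-3-7-v2) so that
      # bumping the version invalidates stale caches. Note that a save_cache
      # key must share the restore_cache key prefixes below, or the saved
      # cache can never be restored on later builds.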
49 |       - restore_cache:
50 |           keys:
51 |             - source-v2-{{ .Branch }}-{{ .Revision }}
52 |             - source-v2-{{ .Branch }}-
53 |             - source-v2-
54 | 
55 |       - checkout
56 | 
57 |       - save_cache:
58 |           key: source-v2-{{ .Branch }}-{{ .Revision }}
59 |           paths:
60 |             - ".git"
61 | 
62 |       - run: pip install --upgrade pip
63 |       - run: pip install -e .
64 | 
65 |       - restore_cache:
66 |           keys:
67 |             - pyenv-3-7-v2-{{ .Branch }}-{{ .Revision }}
68 |             - pyenv-3-7-v2-{{ .Branch }}-
69 |             - pyenv-3-7-v2-
70 | 
71 |       - run: lore test -s tests.unit.__init__
72 | 
73 |       - save_cache:
74 |           key: pyenv-3-7-v2-{{ .Branch }}-{{ .Revision }}
75 |           paths:
76 |             - /home/circleci/.pyenv
77 |             - /home/circleci/python
78 | 
79 |       - run: lore test
80 | 
81 |       - store_test_results:
82 |           path: /tmp/circleci-test-results
83 |       - store_artifacts:
84 |           path: /tmp/circleci-artifacts
85 |       - store_artifacts:
86 |           path: /tmp/circleci-test-results
87 |   "python-3.6.5":
88 |     <<: *shared
89 |     docker:
90 |       - image: circleci/python:3.6.5-stretch
91 |         environment:
92 |           - LORE_PYTHON_VERSION=3.6.5
93 |       - image: circleci/postgres:10-ram
94 |         environment:
95 |           - POSTGRES_USER=circleci
96 |           - POSTGRES_HOST_AUTH_METHOD=trust
97 |           - POSTGRES_DB=lore_test
98 |     steps:
99 |       - run: sudo chown -R circleci:circleci /usr/local/bin
100 |       - run: sudo chown -R circleci:circleci /usr/local/lib/python*
101 | 
102 |       - run: mkdir -p $CIRCLE_ARTIFACTS $CIRCLE_TEST_REPORTS
103 | 
104 |       - restore_cache:
105 |           keys:
106 |             - source-v2-{{ .Branch }}-{{ .Revision }}
107 |             - source-v2-{{ .Branch }}-
108 |             - source-v2-
109 | 
110 |       - checkout
111 | 
112 |       - save_cache:
113 |           key: source-v2-{{ .Branch }}-{{ .Revision }}
114 |           paths:
115 |             - ".git"
116 | 
117 |       - run: pip install --upgrade pip
118 |       - run: pip install -e .
119 | 
120 |       - restore_cache:
121 |           keys:
122 |             - pyenv-3-6-v2-{{ .Branch }}-{{ .Revision }}
123 |             - pyenv-3-6-v2-{{ .Branch }}-
124 |             - pyenv-3-6-v2-
125 | 
126 |       - run: lore test -s tests.unit.__init__
127 | 
128 |       - save_cache:
129 |           key: pyenv-3-6-v2-{{ .Branch }}-{{ .Revision }}
130 |           paths:
131 |             - /home/circleci/.pyenv
132 |             - /home/circleci/python
133 | 
134 |       - run: lore test
135 | 
136 |       - store_test_results:
137 |           path: /tmp/circleci-test-results
138 |       - store_artifacts:
139 |           path: /tmp/circleci-artifacts
140 |       - store_artifacts:
141 |           path: /tmp/circleci-test-results
142 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled python modules.
2 | *.pyc
3 | 
4 | # Setuptools distribution folder.
5 | /dist/
6 | /build/
7 | 
8 | # Python egg metadata, regenerated from source files by setuptools.
9 | /*.egg-info
10 | 
11 | /.eggs
12 | .ipython
13 | 
14 | # IDE files
15 | .idea
16 | .vscode
17 | # Jupyter
18 | notebooks/.ipynb_checkpoints/*
19 | 
20 | logs
21 | tests/models
22 | tests/data
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ### General guidelines and philosophy for contribution
2 | * Include unit tests when you contribute new features, as they help to (a) prove that your code works correctly, and (b) guard against future breaking changes, lowering the maintenance cost.
3 | * Bug fixes also generally require unit tests, because the presence of bugs usually indicates insufficient test coverage.
4 | * Keep API compatibility in mind when you change code. See the Release Checklist below.
5 | * Breaking changes will not be accepted until a major version release.
6 | 
7 | ### Test locally
8 | CI runs for all PRs. Contributions should be compatible with recent versions of Python 2 & 3. To run tests against a specific version of Python:
9 | 
10 | ```bash
11 | $ lore test
12 | $ LORE_PYTHON_VERSION=3.6.5 lore test
13 | $ LORE_PYTHON_VERSION=2.7.15 lore test -s tests.unit.test_encoders.TestUniform.test_cardinality
14 | ```
15 | 
16 | You may need to allow requirements.txt to be recalculated when building different virtualenvs for Python 2 and 3:
17 | ```bash
18 | $ git checkout -- requirements.txt
19 | ```
20 | 
21 | Install a local version of lore in your project's lore env:
22 | 
23 | ```bash
24 | $ git clone https://github.com/instacart/lore ~/repos/lore
25 | $ cd my_project
26 | $ lore pip install -e ~/repos/lore
27 | $ lore test
28 | ```
29 | 
30 | ### Release Checklist:
31 | * Did you add any required properties to Model/Estimator/Pipeline or other Base classes? You need to provide default values for serialized objects during deserialization.
32 | * Did you add any new modules? You need to specify them in setup.py: packages.
33 | * Did you add any new dependencies? Do not add them to setup.py. Instead add them in lore/dependencies.py, and require them only in the modules that need them.
34 | 
35 | ### Python coding style
36 | Changes should conform to the Google Python Style Guide, except feel free to exceed the 80 char line limit.
37 | Keep single logical statements on a single line, and use descriptive names. Underscores for functions and variables, CamelCase for classes, capitalized underscored constants. In general, new code should follow the style of the existing code closest to it.
38 | 
39 | Do not fall prey to the 80 char line length limit. It leads to short, bad names like `q`, `tmp`, `xrt()`. It causes excessive declaration of single-use temporary variables with those bad names, which chop logical statements into incoherent expressions. It discourages the use of named function arguments. It pollutes the global namespace by encouraging `from x import Y`, or worse `import package as pk`. It leads to poorly readable line wrapping of function arguments with _insignificant_ whitespace in a whitespace-_significant_ language. It makes formatting user-facing strings more error prone around whitespace. It breaks URLs in docstrings. It costs developers time to format code. The argument for greater readability does not bear out in practice.
40 | 
41 | Use pylint to check your Python changes.
To install pylint and retrieve Lore's custom style definition: 42 | ```bash 43 | $ pip install pylint 44 | $ wget -O /tmp/pylintrc https://raw.githubusercontent.com/instacart/lore/master/pylintrc 45 | ``` 46 | To check a file with pylint: 47 | ```bash 48 | $ pylint myfile.py 49 | ``` 50 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 Instacart 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE.txt 2 | -------------------------------------------------------------------------------- /PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What 2 | 3 | 4 | ## Why 5 | 6 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ====== 2 | Lore 3 | ====== 4 | 5 | Deprecation Notice 6 | ------------------ 7 | 8 | As of April 2022, Lore has been deprecated at Instacart and will not be supported further. We advise against using Lore in new code. 9 | 10 | The original readme has been moved_. 11 | 12 | .. 
_moved: ORIGINAL_README.rst 13 | -------------------------------------------------------------------------------- /config/database.cfg: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | url: postgres://localhost/lore_test 3 | use_batch_mode: True 4 | 5 | [MAIN_TWO] 6 | url: postgres://localhost/lore_test 7 | use_batch_mode: True 8 | 9 | [METADATA] 10 | url: sqlite:///data/metadata.sqlite 11 | -------------------------------------------------------------------------------- /config/test/aws.cfg: -------------------------------------------------------------------------------- 1 | [IAM] 2 | role: lore-role 3 | 4 | [ACCESS_KEY] 5 | id: foo 6 | secret: foo 7 | 8 | [BUCKET] 9 | name: lore-test 10 | -------------------------------------------------------------------------------- /config/test/database.cfg: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | url: postgres://localhost/lore_test 3 | use_batch_mode: True 4 | 5 | [MAIN_TWO] 6 | url: postgres://localhost/lore_test 7 | use_batch_mode: True 8 | 9 | [METADATA] 10 | url: postgres://localhost/lore_test 11 | use_batch_mode: True 12 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python3 -msphinx 7 | SPHINXPROJ = lore 8 | SOURCEDIR = . 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/README: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/docs/README -------------------------------------------------------------------------------- /docs/api/ansi.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: lore.ansi 2 | :members: 3 | -------------------------------------------------------------------------------- /docs/api/env.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: lore.env 2 | :members: 3 | -------------------------------------------------------------------------------- /docs/api_reference.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ************************** 3 | 4 | .. 
automodule:: lore.encoders 5 | :members: 6 | 7 | 8 | -------------------------------------------------------------------------------- /docs/cli.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/docs/cli.rst -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, unicode_literals 3 | 4 | # 5 | # Configuration file for the Sphinx documentation builder. 6 | # 7 | # This file does only contain a selection of the most common options. For a 8 | # full list see the documentation: 9 | # http://www.sphinx-doc.org/en/stable/config 10 | 11 | # -- Path setup -------------------------------------------------------------- 12 | 13 | # If extensions (or modules to document with autodoc) are in another directory, 14 | # add these directories to sys.path here. If the directory is relative to the 15 | # documentation root, use os.path.abspath to make it absolute, like shown here. 16 | # 17 | import os 18 | import sys 19 | sys.lore_no_env = True 20 | sys.path.insert(0, os.path.abspath('../')) 21 | 22 | import lore 23 | # -- Scrub lore.env information ---------------------------------------------- 24 | version_info = [sys.version_info[0], sys.version_info[1], sys.version_info[2]] 25 | lore.env.HOST = 'localhost' 26 | lore.env.PYTHON_VERSION = '.'.join([str(i) for i in version_info]) 27 | lore.env.PYTHON_VERSION_INFO = version_info 28 | lore.env.ROOT = '.' 29 | lore.env.DATA_DIR = './data' 30 | lore.env.WORK_DIR = '.' 31 | lore.env.MODELS_DIR = './models' 32 | lore.env.LIB_DIR = './libs' 33 | lore.env.ENV_FILE = './.env' 34 | lore.env.HOME = '/home/User' 35 | lore.env.TESTS_DIR = './tests' 36 | lore.env.LOG_DIR = './logs' 37 | lore.env.JUPYTER_KERNEL_PATH = '/' 38 | lore.env.REQUIREMENTS = './requirements.txt' 39 | 40 | 41 | # -- Project information ----------------------------------------------------- 42 | 43 | project = 'Lore' 44 | copyright = '2018, Instacart' 45 | author = 'Montana Low and Jeremy Stanley' 46 | 47 | # The short X.Y version 48 | version = '.'.join(lore.__version__.split('.')[0:1]) 49 | # The full version, including alpha/beta/rc tags 50 | release = lore.__version__ 51 | 52 | 53 | # -- General configuration --------------------------------------------------- 54 | 55 | # If your documentation needs a minimal Sphinx version, state it here. 56 | # 57 | # needs_sphinx = '1.0' 58 | 59 | # Add any Sphinx extension module names here, as strings. They can be 60 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 61 | # ones. 62 | extensions = [ 63 | 'sphinx.ext.autodoc', 64 | 'sphinx.ext.doctest', 65 | 'sphinx.ext.intersphinx', 66 | 'sphinx.ext.todo', 67 | 'sphinx.ext.coverage', 68 | 'sphinx.ext.mathjax', 69 | 'sphinx.ext.ifconfig', 70 | 'sphinx.ext.viewcode', 71 | ] 72 | 73 | # Add any paths that contain templates here, relative to this directory. 74 | templates_path = ['_templates'] 75 | 76 | # The suffix(es) of source filenames. 77 | # You can specify multiple suffix as a list of string: 78 | # 79 | # source_suffix = ['.rst', '.md'] 80 | source_suffix = '.rst' 81 | 82 | # The master toctree document. 83 | master_doc = 'index' 84 | 85 | # The language for content autogenerated by Sphinx. Refer to documentation 86 | # for a list of supported languages. 
87 | # 88 | # This is also used if you do content translation via gettext catalogs. 89 | # Usually you set "language" from the command line for these cases. 90 | language = None 91 | 92 | # List of patterns, relative to source directory, that match files and 93 | # directories to ignore when looking for source files. 94 | # This pattern also affects html_static_path and html_extra_path . 95 | exclude_patterns = [] 96 | 97 | # The name of the Pygments (syntax highlighting) style to use. 98 | pygments_style = 'sphinx' 99 | 100 | 101 | # -- Options for HTML output ------------------------------------------------- 102 | 103 | # The theme to use for HTML and HTML Help pages. See the documentation for 104 | # a list of builtin themes. 105 | # 106 | html_theme = 'sphinx_rtd_theme' 107 | 108 | # Theme options are theme-specific and customize the look and feel of a theme 109 | # further. For a list of options available for each theme, see the 110 | # documentation. 111 | # 112 | # html_theme_options = {} 113 | 114 | # Add any paths that contain custom static files (such as style sheets) here, 115 | # relative to this directory. They are copied after the builtin static files, 116 | # so a file named "default.css" will overwrite the builtin "default.css". 117 | html_static_path = ['_static'] 118 | 119 | # Custom sidebar templates, must be a dictionary that maps document names 120 | # to template names. 121 | # 122 | # The default sidebars (for documents that don't match any pattern) are 123 | # defined by theme itself. Builtin themes are using these templates by 124 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 125 | # 'searchbox.html']``. 126 | # 127 | # html_sidebars = {} 128 | 129 | 130 | # -- Options for HTMLHelp output --------------------------------------------- 131 | 132 | # Output file base name for HTML help builder. 133 | htmlhelp_basename = 'Loredoc' 134 | 135 | 136 | # -- Options for LaTeX output ------------------------------------------------ 137 | 138 | latex_elements = { 139 | # The paper size ('letterpaper' or 'a4paper'). 140 | # 141 | # 'papersize': 'letterpaper', 142 | 143 | # The font size ('10pt', '11pt' or '12pt'). 144 | # 145 | # 'pointsize': '10pt', 146 | 147 | # Additional stuff for the LaTeX preamble. 148 | # 149 | # 'preamble': '', 150 | 151 | # Latex figure (float) alignment 152 | # 153 | # 'figure_align': 'htbp', 154 | } 155 | 156 | # Grouping the document tree into LaTeX files. List of tuples 157 | # (source start file, target name, title, 158 | # author, documentclass [howto, manual, or own class]). 159 | latex_documents = [ 160 | (master_doc, 'Lore.tex', 'Lore Documentation', 161 | 'Montana Low and Jeremy Stanley', 'manual'), 162 | ] 163 | 164 | 165 | # -- Options for manual page output ------------------------------------------ 166 | 167 | # One entry per manual page. List of tuples 168 | # (source start file, name, description, authors, manual section). 169 | man_pages = [ 170 | (master_doc, 'lore', 'Lore Documentation', 171 | [author], 1) 172 | ] 173 | 174 | 175 | # -- Options for Texinfo output ---------------------------------------------- 176 | 177 | # Grouping the document tree into Texinfo files. 
List of tuples 178 | # (source start file, target name, title, author, 179 | # dir menu entry, description, category) 180 | texinfo_documents = [ 181 | (master_doc, 'Lore', 'Lore Documentation', 182 | author, 'Lore', 'Machine Learning Framework for Data Scientists by Engineers', 183 | 'Miscellaneous'), 184 | ] 185 | 186 | 187 | # -- Extension configuration ------------------------------------------------- 188 | 189 | # -- Options for intersphinx extension --------------------------------------- 190 | 191 | # Example configuration for intersphinx: refer to the Python standard library. 192 | intersphinx_mapping = {'https://docs.python.org/': None} 193 | 194 | # -- Options for todo extension ---------------------------------------------- 195 | 196 | # If true, `todo` and `todoList` produce output, else they produce nothing. 197 | todo_include_todos = True 198 | -------------------------------------------------------------------------------- /docs/config.rst: -------------------------------------------------------------------------------- 1 | Configuration 2 | ******************************** 3 | 4 | Table of Contents 5 | ================================ 6 | 7 | .. toctree:: 8 | :maxdepth: 2 9 | :caption: Contents: 10 | 11 | Databases 12 | 13 | 14 | Databases 15 | ========= 16 | 17 | Amazon Web Services 18 | =================== 19 | 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | 28 | -------------------------------------------------------------------------------- /docs/images/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/docs/images/model.png -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Lore 2 | ******************************** 3 | 4 | Table of Contents 5 | ================================ 6 | 7 | .. 
toctree:: 8 | :maxdepth: 2 9 | :caption: Contents: 10 | 11 | cli 12 | www 13 | models 14 | io 15 | utils 16 | api_reference 17 | 18 | 19 | 20 | Indices and tables 21 | ================== 22 | 23 | * :ref:`genindex` 24 | * :ref:`modindex` 25 | * :ref:`search` 26 | -------------------------------------------------------------------------------- /docs/io.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/docs/io.rst -------------------------------------------------------------------------------- /docs/models.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/docs/models.rst -------------------------------------------------------------------------------- /docs/utils.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/docs/utils.rst -------------------------------------------------------------------------------- /docs/www.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/docs/www.rst -------------------------------------------------------------------------------- /lore/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, unicode_literals 3 | 4 | import logging 5 | import sys 6 | import os 7 | 8 | import lore.dependencies 9 | from lore import env, util, ansi 10 | from lore.util import timer 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | __author__ = 'Montana Low and Jeremy Stanley' 17 | __copyright__ = 'Copyright © 2018, Instacart' 18 | __credits__ = ['Montana Low', 'Jeremy Stanley', 'Emmanuel Turlay', 'Shrikar Archak', 'Ganesh Krishnan'] 19 | __license__ = 'MIT' 20 | __version__ = '0.8.6' 21 | __maintainer__ = 'Montana Low' 22 | __email__ = 'montana@instacart.com' 23 | __status__ = 'Development Status :: 4 - Beta' 24 | 25 | 26 | def banner(): 27 | import socket 28 | import getpass 29 | 30 | return '%s in %s on %s with %s & %s' % ( 31 | ansi.foreground(ansi.GREEN, env.APP), 32 | ansi.foreground(env.COLOR, env.NAME), 33 | ansi.foreground( 34 | ansi.CYAN, 35 | getpass.getuser() + '@' + socket.gethostname() 36 | ), 37 | ansi.foreground(ansi.YELLOW, 'Python ' + env.PYTHON_VERSION), 38 | ansi.foreground(ansi.YELLOW, 'Lore ' + __version__) 39 | ) 40 | 41 | 42 | lore_no_env = False 43 | if hasattr(sys, 'lore_no_env'): 44 | lore_no_env = sys.lore_no_env 45 | 46 | 47 | no_env_commands = ['--version', 'install', 'init', 'server'] 48 | if len(sys.argv) > 1 and os.path.basename(sys.argv[0]) in ['lore', 'lore.exe'] and sys.argv[1] in no_env_commands: 49 | lore_no_env = True 50 | 51 | if '--no-env' in sys.argv: 52 | lore_no_env = True 53 | 54 | if not lore_no_env: 55 | # everyone else gets validated and launched on import 56 | env.validate() 57 | env.launch() 58 | 59 | if env.launched(): 60 | print(banner()) 61 | logger.info(banner()) 62 | logger.debug('python environment: %s' % env.PREFIX) 63 | 64 | if not lore_no_env: 65 | with timer('check requirements', logging.DEBUG): 66 | env.check_requirements() 67 | -------------------------------------------------------------------------------- 
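A sketch of the `lore_no_env` escape hatch implemented in lore/__init__.py
above (this is the same trick docs/conf.py uses so Sphinx can import lore
without a project environment):

    import sys
    sys.lore_no_env = True   # must be set before `import lore`
    import lore              # skips env.validate() and env.launch()

--------------------------------------------------------------------------------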
/lore/ansi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ANSI makes it easy! 4 | ******************* 5 | 6 | :any:`lore.ansi` makes formatting text output super simple! Lore doesn't have 7 | much of a UI. Text output should be excellent. 8 | 9 | .. role:: strike 10 | :class: strike 11 | """ 12 | from __future__ import absolute_import, print_function, unicode_literals 13 | 14 | import platform 15 | 16 | RESET = 0 #: Game over! 17 | 18 | BOLD = 1 #: For people with a heavy hand. 19 | FAINT = 2 #: For people with a light touch. 20 | ITALIC = 3 #: Are Italian's emphatic, or is Italy slanted? `Etymology `_ is fun. 21 | UNDERLINE = 4 #: It's got a line under it, no need for a PhD here. 22 | STROBE = 5 #: For sadists looking to cause seizures. Doesn't work except on masochist's platforms. 23 | BLINK = 6 #: For that patiently waiting cursor effect. Also doesn't work, since sadists ruined it for everyone. 24 | INVERSE = 7 #: Today is backwards day. 25 | CONCEAL = 8 #: Why would you do this‽ Your attempt has been logged, and will be reported to the authorities. 26 | STRIKE = 9 #: Adopt that sense of humility, let other people know you w̶e̶r̶e̶ ̶w̶r̶o̶n̶g learned from experience. 27 | 28 | BLACK = 30 #: If you gaze long into an abyss, the abyss will gaze back into you. 29 | RED = 31 #: Hot and loud. Like a fire engine, anger, or you've just bitten your cheek for the third time today. 30 | GREEN = 32 #: The most refreshingly natural color. Growth and softness. 31 | YELLOW = 33 #: Daisies and deadly poison dart frogs. Salted butter and lightning. Like scotch filtered through gold foil. 32 | BLUE = 34 #: Skies, oceans, infinite depths. The color of hope and melancholy. 33 | MAGENTA = 35 #: For latin salsa dresses with matching shoes. Also, the radiant color of T brown dwarf stars as long as sodium and potassium atoms absorb the :any:`GREEN` light in the spectrum. 34 | CYAN = 36 #: Only printers who prefer CMYK over RGB would name this color. It's :any:`BLUE` stripped of soul, injected with 10,000 volts. A true Frankenstein's monster. 35 | WHITE = 37 #: The sum of all colors, to the point there is no color left at all. Floats nicely in the abyss. 36 | DEFAULT = 39 #: You get 3 guesses what this color is, and the first 2 don't count. 37 | 38 | 39 | if platform.system() == 'Windows': 40 | ### 41 | # This is an unfortunate hack to enable ansi output on Windows 42 | # https://bugs.python.org/issue30075 43 | import msvcrt 44 | import ctypes 45 | import os 46 | 47 | from ctypes import wintypes # pylint: disable=ungrouped-imports 48 | 49 | kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) 50 | 51 | ERROR_INVALID_PARAMETER = 0x0057 52 | ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004 53 | 54 | def _check_bool(result, func, args): 55 | if not result: 56 | raise ctypes.WinError(ctypes.get_last_error()) 57 | return args 58 | 59 | LPDWORD = ctypes.POINTER(wintypes.DWORD) 60 | kernel32.GetConsoleMode.errcheck = _check_bool 61 | kernel32.GetConsoleMode.argtypes = (wintypes.HANDLE, LPDWORD) 62 | kernel32.SetConsoleMode.errcheck = _check_bool 63 | kernel32.SetConsoleMode.argtypes = (wintypes.HANDLE, wintypes.DWORD) 64 | 65 | def set_conout_mode(new_mode, mask=0xffffffff): 66 | # don't assume StandardOutput is a console. 
67 | # open CONOUT$ instead 68 | fdout = os.open('CONOUT$', os.O_RDWR) 69 | try: 70 | hout = msvcrt.get_osfhandle(fdout) 71 | old_mode = wintypes.DWORD() 72 | kernel32.GetConsoleMode(hout, ctypes.byref(old_mode)) 73 | mode = (new_mode & mask) | (old_mode.value & ~mask) 74 | kernel32.SetConsoleMode(hout, mode) 75 | return old_mode.value 76 | finally: 77 | os.close(fdout) 78 | 79 | def enable_vt_mode(): 80 | mode = mask = ENABLE_VIRTUAL_TERMINAL_PROCESSING 81 | try: 82 | return set_conout_mode(mode, mask) 83 | except WindowsError as e: # pylint: disable=undefined-variable 84 | if e.winerror == ERROR_INVALID_PARAMETER: 85 | raise NotImplementedError 86 | raise 87 | import atexit 88 | atexit.register(set_conout_mode, enable_vt_mode()) 89 | 90 | 91 | def debug(content='DEBUG'): 92 | """ debug style 93 | 94 | :param content: Whatever you want to say... 95 | :type content: unicode 96 | :return: ansi string 97 | :rtype: unicode 98 | """ 99 | return gray(7, content) 100 | 101 | 102 | def info(content='INFO'): 103 | """ info style 104 | 105 | :param content: Whatever you want to say... 106 | :type content: unicode 107 | :return: ansi string 108 | :rtype: unicode 109 | """ 110 | return foreground(BLUE, content) 111 | 112 | 113 | def warning(content='WARNING'): 114 | """ warning style 115 | 116 | :param content: Whatever you want to say... 117 | :type content: unicode 118 | :return: ansi string 119 | :rtype: unicode 120 | """ 121 | return foreground(MAGENTA, content) 122 | 123 | 124 | def success(content='SUCCESS'): 125 | """ success style 126 | 127 | :param content: Whatever you want to say... 128 | :type content: unicode 129 | :return: ansi string 130 | :rtype: unicode 131 | """ 132 | return foreground(GREEN, content) 133 | 134 | 135 | def error(content='ERROR'): 136 | """ error style 137 | 138 | :param content: Whatever you want to say... 139 | :type content: unicode 140 | :return: ansi string 141 | :rtype: unicode 142 | """ 143 | return foreground(RED, content) 144 | 145 | 146 | def critical(content='CRITICAL'): 147 | """ for really big fuck ups, not to be used lightly. 148 | 149 | :param content: Whatever you want to say... 150 | :type content: unicode 151 | :return: ansi string 152 | :rtype: unicode 153 | """ 154 | return blink(foreground(bright(RED), content)) 155 | 156 | 157 | def foreground(color, content, readline=False): 158 | """ Color the text of the content 159 | 160 | :param color: pick a constant, any constant 161 | :type color: int 162 | :param content: Whatever you want to say... 163 | :type content: unicode 164 | :return: ansi string 165 | :rtype: unicode 166 | """ 167 | return encode(color, readline=readline) + content + encode(DEFAULT, readline=readline) 168 | 169 | 170 | def background(color, content): 171 | """ Color the background of the content 172 | 173 | :param color: pick a constant, any constant 174 | :type color: int 175 | :param content: Whatever you want to say... 
176 |     :type content: unicode
177 |     :return: ansi string
178 |     :rtype: unicode
179 |     """
180 |     return encode(color + 10) + content + encode(DEFAULT + 10)
181 | 
182 | 
183 | def bright(color):
184 |     """ Brighten a color
185 | 
186 |     :param color: pick a constant, any constant
187 |     :type color: int
188 | 
189 |     :return: brighter version of the color
190 |     :rtype: int
191 |     """
192 |     return color + 60
193 | 
194 | 
195 | def gray(level, content):
196 |     """ Grayscale
197 | 
198 |     :param level: [0-15] 0 is almost black, 15 is nearly white
199 |     :type level: int
200 |     :param content: Whatever you want to say...
201 |     :type content: unicode
202 |     :return: ansi string
203 |     :rtype: unicode
204 |     """
205 |     return encode('38;5;%i' % (232 + level)) + content + encode(DEFAULT)
206 | 
207 | 
208 | def rgb(red, green, blue, content):
209 |     """ Colors the content using an rgb value
210 |     :param red: [0-5]
211 |     :type red: int
212 |     :param green: [0-5]
213 |     :type green: int
214 |     :param blue: [0-5]
215 |     :type blue: int
216 |     :param content: Whatever you want to say...
217 |     :type content: unicode
218 |     :return: ansi string
219 |     :rtype: unicode
220 |     """
221 |     color = 16 + 36 * red + 6 * green + blue
222 |     return encode('38;5;' + str(color)) + content + encode(DEFAULT)
223 | 
224 | 
225 | def bold(content):
226 |     """ Bold content
227 | 
228 |     :param content: Whatever you want to say...
229 |     :type content: unicode
230 |     :return: ansi string
231 |     :rtype: unicode
232 |     """
233 |     return style(BOLD, content)
234 | 
235 | 
236 | def faint(content):
237 |     """ Faint content
238 | 
239 |     :param content: Whatever you want to say...
240 |     :type content: unicode
241 |     :return: ansi string
242 |     :rtype: unicode
243 |     """
244 |     return style(FAINT, content)
245 | 
246 | 
247 | def italic(content):
248 |     """ Italic content
249 | 
250 |     :param content: Whatever you want to say...
251 |     :type content: unicode
252 |     :return: ansi string
253 |     :rtype: unicode
254 |     """
255 |     return style(ITALIC, content)
256 | 
257 | 
258 | def underline(content):
259 |     """ Underline content
260 | 
261 |     :param content: Whatever you want to say...
262 |     :type content: unicode
263 |     :return: ansi string
264 |     :rtype: unicode
265 |     """
266 |     return style(UNDERLINE, content)
267 | 
268 | 
269 | def strobe(content):
270 |     """ Quickly blinking content
271 | 
272 |     :param content: Whatever you want to say...
273 |     :type content: unicode
274 |     :return: ansi string
275 |     :rtype: unicode
276 |     """
277 |     return style(STROBE, content)
278 | 
279 | 
280 | def blink(content):
281 |     """ Slowly blinking content
282 | 
283 |     :param content: Whatever you want to say...
284 |     :type content: unicode
285 |     :return: ansi string
286 |     :rtype: unicode
287 |     """
288 |     return style(BLINK, content)
289 | 
290 | 
291 | def inverse(content):
292 |     """ Inverted content
293 | 
294 |     :param content: Whatever you want to say...
295 |     :type content: unicode
296 |     :return: ansi string
297 |     :rtype: unicode
298 |     """
299 |     return style(INVERSE, content)
300 | 
301 | 
302 | def conceal(content):
303 |     """ Why do you persist in this nonsense?
304 | 
305 |     :param content: Whatever you want to say...
306 |     :type content: unicode
307 |     :return: ansi string
308 |     :rtype: unicode
309 |     """
310 |     return style(CONCEAL, content)
311 | 
312 | 
313 | def strike(content):
314 |     """ Strike through content
315 | 
316 |     :param content: Whatever you want to say...
317 |     :type content: unicode
318 |     :return: ansi string
319 |     :rtype: unicode
320 |     """
321 |     return style(STRIKE, content)
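# Usage sketch for the style helpers above (illustrative, not part of the
# module): every helper returns `content` wrapped in the corresponding escape
# codes, so styles compose by nesting calls.
#
#     from lore import ansi
#     print(ansi.foreground(ansi.RED, 'alert'))       # red text
#     print(ansi.bold(ansi.underline('emphasis')))    # stacked styles
#     print(ansi.gray(10, 'subtle'))                  # mid grayscale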
322 | 
323 | 
324 | def reset():
325 |     """ Remove any active ansi styles
326 | 
327 |     :return: string resetting marker
328 |     :rtype: unicode
329 |     """
330 |     return encode(RESET)
331 | 
332 | 
333 | def style(effect, content):
334 |     """ add a particular style to the content
335 | 
336 |     :param effect: style
337 |     :type effect: int
338 |     :param content: Whatever you want to say...
339 |     :type content: unicode
340 |     :return: ansi string
341 |     :rtype: unicode
342 |     """
343 |     return encode(effect) + content + encode(RESET)
344 | 
345 | 
346 | def encode(code, readline=False):
347 |     """ Adds escape and control characters for ANSI codes
348 | 
349 |     :param code: pick a constant, any constant
350 |     :type code: int
351 |     :param readline: add readline compatibility, which causes bugs in other formats
352 |     :type readline: bool
353 |     :return: ansi string
354 |     :rtype: unicode
355 |     """
356 |     if readline:
357 |         return '\001\033[' + str(code) + 'm\002'
358 | 
359 |     return '\033[' + str(code) + 'm'
--------------------------------------------------------------------------------
/lore/callbacks.py:
--------------------------------------------------------------------------------
1 | import lore
2 | from lore.env import require
3 | from lore.util import timer
4 | 
5 | import logging
6 | from datetime import datetime
7 | 
8 | require(lore.dependencies.KERAS)
9 | import keras.callbacks
10 | 
11 | 
12 | logger = logging.getLogger(__name__)
13 | 
14 | 
15 | class ReloadBest(keras.callbacks.ModelCheckpoint):
16 |     def __init__(
17 |         self,
18 |         filepath,
19 |         monitor='val_loss',
20 |         mode='auto',
21 |     ):
22 |         super(ReloadBest, self).__init__(
23 |             filepath=filepath,
24 |             monitor=monitor,
25 |             verbose=0,
26 |             mode=mode,
27 |             save_best_only=False,
28 |             save_weights_only=True,
29 |             period=1
30 |         )
31 |         self.train_loss = None
32 |         self.validate_loss = None
33 |         self.best_epoch = None
34 |         self.train_begin = None
35 | 
36 |     def on_train_begin(self, logs=None):
37 |         super(ReloadBest, self).on_train_begin(logs)
38 | 
39 |         self.train_begin = datetime.utcnow()
40 |         logger.info('=============================================')
41 |         logger.info('| epoch | time | train | validate |')
42 |         logger.info('---------------------------------------------')
43 | 
44 |     def on_train_end(self, logs=None):
45 |         super(ReloadBest, self).on_train_end(logs)
46 |         logger.info('=============================================')
47 |         if self.best_epoch is not None:
48 |             logger.debug('best epoch: %i' % self.best_epoch)
49 |             with timer('load best epoch'):
50 |                 self.model.load_weights(
51 |                     self.filepath.format(epoch=self.best_epoch)
52 |                 )
53 | 
54 |     def on_epoch_end(self, epoch, logs=None):
55 |         super(ReloadBest, self).on_epoch_end(epoch, logs)
56 |         time = datetime.utcnow() - self.train_begin
57 |         train_loss = logs.get('loss')
58 |         validate_loss = logs.get('val_loss')
59 |         if validate_loss:
60 |             if self.validate_loss is None or self.validate_loss > validate_loss:
61 |                 self.best_epoch = epoch + 1
62 |                 self.train_loss = train_loss
63 |                 self.validate_loss = validate_loss
64 |         else:
65 |             logger.error('No val_loss in logs, setting to NaN')
66 |             validate_loss = float('nan')
67 |         logger.info('| %8i | %8s | %8.4f | %8.4f |' % (
68 |             epoch, str(time).split('.', 2)[0], train_loss, validate_loss)
69 |         )
70 | 
--------------------------------------------------------------------------------
/lore/dependencies.py:
--------------------------------------------------------------------------------
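# Each constant below is a list of pip requirement pins. Per CONTRIBUTING.md,
# new dependencies are declared here instead of setup.py, and each module
# opts in at import time with lore.env.require(lore.dependencies.<GROUP>).
# A hypothetical new group would follow the same shape, e.g.:
#
#     LIGHTGBM = ['lightgbm>=2.0, <2.99']  # illustrative only, not in this file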
1 | DATEUTIL = ['python-dateutil>=2.1, <2.7.0']
2 | FLASK = ['flask>=0.11.0, <0.12.99']
3 | FUTURE = ['future>=0.15, <0.16.99']
4 | INFLECTION = ['inflection>=0.3, <0.3.99']
5 | JINJA = ['Jinja2>=2.9.0, <2.10.0']
6 | JUPYTER = [
7 |     'jupyter>=1.0, <1.0.99',
8 |     'jupyter-core>=4.4.0, <4.4.99',
9 | ]
10 | NUMPY = ['numpy>=1.14, <1.14.99']
11 | PANDAS = ['pandas>=0.20, <0.23.99, !=0.22.0']
12 | TABULATE = ['tabulate>=0.7.5, <0.8.99']
13 | SHAP = ['shap>=0.12.0, <0.12.99']
14 | 
15 | SQL = ['sqlalchemy>=1.2.0b3, <1.2.99', 'sqlalchemy-migrate>=0.11, <0.11.99']
16 | SNOWFLAKE = [
17 |     'snowflake-connector-python>=2.0.2, <3.0.0',
18 |     'snowflake-sqlalchemy>=1.1.0, <1.2.0',
19 | ]
20 | POSTGRES = ['psycopg2>=2.7, <2.7.99'] + SQL
21 | REDSHIFT = ['sqlalchemy-redshift>=0.7, <0.7.99'] + SQL
22 | REDIS = ['redis>=2.10, <2.10.99']
23 | S3 = ['boto3>=1.4, <1.7.99'] + DATEUTIL
24 | SMART_OPEN = ['smart-open>=1.5, <1.5.99'] + S3
25 | GEOIP = ['geoip2']
26 | H5PY = ['h5py>=2.7, <2.8.99',]
27 | KERAS = [
28 |     'Keras>=2.0.9, <2.1.99',
29 |     'tensorflow>=1.3, <2.0.0',
30 |     'dill>=0.2, <0.2.99',
31 |     'bleach==1.5.0',
32 |     'html5lib==0.9999999',
33 |     'pydot>=1.2.4, <1.2.99',
34 |     'graphviz>=0.8.2, <0.8.99'] + H5PY
35 | XGBOOST = ['xgboost>=0.72, <0.80']
36 | SKLEARN = ['scikit-learn>=0.19, <0.19.99']
37 | 
38 | ALL = list(set(
39 |     DATEUTIL +
40 |     FLASK +
41 |     FUTURE +
42 |     INFLECTION +
43 |     JINJA +
44 |     JUPYTER +
45 |     NUMPY +
46 |     PANDAS +
47 |     TABULATE +
48 |     SHAP +
49 |     SQL +
50 |     SNOWFLAKE +
51 |     POSTGRES +
52 |     REDSHIFT +
53 |     REDIS +
54 |     S3 +
55 |     SMART_OPEN +
56 |     GEOIP +
57 |     H5PY +
58 |     KERAS +
59 |     XGBOOST +
60 |     SKLEARN
61 | ))
62 | 
63 | TEST = ALL + [
64 |     'moto>=1.1, <1.3.99'
65 | ]
66 | 
67 | DOC = ALL + [
68 |     'sphinx',
69 |     'sphinx-autobuild',
70 |     'sphinx_rtd_theme'
71 | ]
72 | 
--------------------------------------------------------------------------------
/lore/estimators/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | lore.estimators
3 | """
4 | from __future__ import absolute_import
5 | 
6 | from abc import ABCMeta, abstractmethod
7 | import logging
8 | from sklearn.base import BaseEstimator
9 | 
10 | from lore.util import timed, before_after_callbacks
11 | 
12 | 
13 | class Base(BaseEstimator):
14 |     """Base class for estimators"""
15 | 
16 |     __metaclass__ = ABCMeta
17 | 
18 |     @before_after_callbacks
19 |     @timed(logging.INFO)
20 |     @abstractmethod
21 |     def fit(self):
22 |         pass
23 | 
24 |     @before_after_callbacks
25 |     @timed(logging.INFO)
26 |     @abstractmethod
27 |     def predict(self):
28 |         pass
29 | 
30 |     @before_after_callbacks
31 |     @timed(logging.INFO)
32 |     @abstractmethod
33 |     def evaluate(self):
34 |         pass
35 | 
36 |     @before_after_callbacks
37 |     @timed(logging.INFO)
38 |     @abstractmethod
39 |     def score(self):
40 |         pass
41 | 
--------------------------------------------------------------------------------
/lore/estimators/holt_winters/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import lore
3 | from lore.env import require
4 | from lore.util import timed
5 | from lore.estimators.holt_winters.holtwinters import additive
6 | 
7 | require(lore.dependencies.SKLEARN)
8 | 
9 | from sklearn.base import BaseEstimator
10 | 
11 | 
12 | class HoltWinters(BaseEstimator):
13 | 
14 |     def __init__(self, **kwargs):
15 |         super(HoltWinters, self).__init__()
16 |         self.periodicity = kwargs.get('periodicity')
17 |         self.forecasts = kwargs.get('days_to_forecast')
18 |         self.kwargs = kwargs
19 |         self.params = None
20 | 
21 |     @timed(logging.INFO)
22 |     def fit(self, x, y=None):
23 |         results = additive(x, self.periodicity, self.forecasts,
24 |                            alpha=self.kwargs.get('alpha'),
25 |                            beta=self.kwargs.get('beta'),
26 |                            gamma=self.kwargs.get('gamma'))
27 |         self.params = {'alpha': results[1], 'beta': results[2], 'gamma': results[3]}
28 |         self.rmse = results[4]
29 |         return {'alpha': results[1], 'beta': results[2], 'gamma': results[3], 'RMSE': self.rmse}
30 | 
31 |     @timed(logging.INFO)
32 |     def predict(self, X):
33 |         return additive(X, self.periodicity, self.forecasts, **self.params)[0]
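A usage sketch for the estimator above (the series is hypothetical;
`periodicity` is the season length and `days_to_forecast` the horizon, both
read from **kwargs in __init__):

    model = HoltWinters(periodicity=7, days_to_forecast=14)
    model.fit(list(daily_series))                   # fits alpha/beta/gamma, returns them with RMSE
    forecast = model.predict(list(daily_series))    # the next 14 forecast values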
--------------------------------------------------------------------------------
/lore/estimators/holt_winters/holtwinters.py:
--------------------------------------------------------------------------------
1 | #The MIT License (MIT)
2 | #
3 | #Copyright (c) 2015 Andre Queiroz
4 | #
5 | #Permission is hereby granted, free of charge, to any person obtaining a copy
6 | #of this software and associated documentation files (the "Software"), to deal
7 | #in the Software without restriction, including without limitation the rights
8 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | #copies of the Software, and to permit persons to whom the Software is
10 | #furnished to do so, subject to the following conditions:
11 | #
12 | #The above copyright notice and this permission notice shall be included in
13 | #all copies or substantial portions of the Software.
14 | #
15 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | #THE SOFTWARE.
22 | #
23 | # Holt-Winters algorithms for forecasting
24 | # Coded in Python 2 by: Andre Queiroz
25 | # Description: This module contains three exponential smoothing algorithms. They are Holt's linear trend method and Holt-Winters seasonal methods (additive and multiplicative).
26 | # References:
27 | # Hyndman, R. J.; Athanasopoulos, G. (2013) Forecasting: principles and practice. http://otexts.com/fpp/. Accessed on 07/03/2013.
28 | # Byrd, R. H.; Lu, P.; Nocedal, J. A Limited Memory Algorithm for Bound Constrained Optimization, (1995), SIAM Journal on Scientific and Statistical Computing, 16, 5, pp. 1190-1208.
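# The additive seasonal recurrences implemented below (period m; smoothing
# parameters alpha, beta, gamma; level a, trend b, season s) match the
# Hyndman & Athanasopoulos reference above:
#
#     a[i+1] = alpha * (Y[i] - s[i]) + (1 - alpha) * (a[i] + b[i])   # level
#     b[i+1] = beta * (a[i+1] - a[i]) + (1 - beta) * b[i]            # trend
#     s[i+m] = gamma * (Y[i] - a[i] - b[i]) + (1 - gamma) * s[i]     # season
#     y[i+1] = a[i+1] + b[i+1] + s[i+1]                              # one-step forecast
#
# The multiplicative variant divides out the seasonal component instead of
# subtracting it.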
29 | 30 | # https://gist.github.com/andrequeiroz/5888967 31 | 32 | from __future__ import division 33 | from sys import exit 34 | from math import sqrt 35 | from numpy import array 36 | from scipy.optimize import fmin_l_bfgs_b 37 | 38 | def RMSE(params, *args): 39 | 40 | Y = args[0] 41 | type = args[1] 42 | rmse = 0 43 | 44 | if type == 'linear': 45 | 46 | alpha, beta = params 47 | a = [Y[0]] 48 | b = [Y[1] - Y[0]] 49 | y = [a[0] + b[0]] 50 | 51 | for i in range(len(Y)): 52 | 53 | a.append(alpha * Y[i] + (1 - alpha) * (a[i] + b[i])) 54 | b.append(beta * (a[i + 1] - a[i]) + (1 - beta) * b[i]) 55 | y.append(a[i + 1] + b[i + 1]) 56 | 57 | else: 58 | 59 | alpha, beta, gamma = params 60 | m = args[2] 61 | a = [sum(Y[0:m]) / float(m)] 62 | b = [(sum(Y[m:2 * m]) - sum(Y[0:m])) / m ** 2] 63 | 64 | if type == 'additive': 65 | 66 | s = [Y[i] - a[0] for i in range(m)] 67 | y = [a[0] + b[0] + s[0]] 68 | 69 | for i in range(len(Y)): 70 | 71 | a.append(alpha * (Y[i] - s[i]) + (1 - alpha) * (a[i] + b[i])) 72 | b.append(beta * (a[i + 1] - a[i]) + (1 - beta) * b[i]) 73 | s.append(gamma * (Y[i] - a[i] - b[i]) + (1 - gamma) * s[i]) 74 | y.append(a[i + 1] + b[i + 1] + s[i + 1]) 75 | 76 | elif type == 'multiplicative': 77 | 78 | s = [Y[i] / a[0] for i in range(m)] 79 | y = [(a[0] + b[0]) * s[0]] 80 | 81 | for i in range(len(Y)): 82 | 83 | a.append(alpha * (Y[i] / s[i]) + (1 - alpha) * (a[i] + b[i])) 84 | b.append(beta * (a[i + 1] - a[i]) + (1 - beta) * b[i]) 85 | s.append(gamma * (Y[i] / (a[i] + b[i])) + (1 - gamma) * s[i]) 86 | y.append((a[i + 1] + b[i + 1]) * s[i + 1]) 87 | 88 | else: 89 | 90 | exit('Type must be either linear, additive or multiplicative') 91 | 92 | rmse = sqrt(sum([(m - n) ** 2 for m, n in zip(Y, y[:-1])]) / len(Y)) 93 | 94 | return rmse 95 | 96 | def linear(x, fc, alpha = None, beta = None): 97 | 98 | Y = x[:] 99 | 100 | if (alpha == None or beta == None): 101 | 102 | initial_values = array([0.3, 0.1]) 103 | boundaries = [(0, 1), (0, 1)] 104 | type = 'linear' 105 | 106 | parameters = fmin_l_bfgs_b(RMSE, x0 = initial_values, args = (Y, type), bounds = boundaries, approx_grad = True) 107 | alpha, beta = parameters[0] 108 | 109 | a = [Y[0]] 110 | b = [Y[1] - Y[0]] 111 | y = [a[0] + b[0]] 112 | rmse = 0 113 | 114 | for i in range(len(Y) + fc): 115 | 116 | if i == len(Y): 117 | Y.append(a[-1] + b[-1]) 118 | 119 | a.append(alpha * Y[i] + (1 - alpha) * (a[i] + b[i])) 120 | b.append(beta * (a[i + 1] - a[i]) + (1 - beta) * b[i]) 121 | y.append(a[i + 1] + b[i + 1]) 122 | 123 | rmse = sqrt(sum([(m - n) ** 2 for m, n in zip(Y[:-fc], y[:-fc - 1])]) / len(Y[:-fc])) 124 | 125 | return Y[-fc:], alpha, beta, rmse 126 | 127 | def additive(x, m, fc, alpha = None, beta = None, gamma = None): 128 | 129 | Y = x[:] 130 | 131 | if (alpha == None or beta == None or gamma == None): 132 | 133 | initial_values = array([0.3, 0.1, 0.1]) 134 | boundaries = [(0, 1), (0, 1), (0, 1)] 135 | type = 'additive' 136 | 137 | parameters = fmin_l_bfgs_b(RMSE, x0 = initial_values, args = (Y, type, m), bounds = boundaries, approx_grad = True) 138 | alpha, beta, gamma = parameters[0] 139 | 140 | a = [sum(Y[0:m]) / float(m)] 141 | b = [(sum(Y[m:2 * m]) - sum(Y[0:m])) / m ** 2] 142 | s = [Y[i] - a[0] for i in range(m)] 143 | y = [a[0] + b[0] + s[0]] 144 | rmse = 0 145 | 146 | for i in range(len(Y) + fc): 147 | 148 | if i == len(Y): 149 | Y.append(a[-1] + b[-1] + s[-m]) 150 | 151 | a.append(alpha * (Y[i] - s[i]) + (1 - alpha) * (a[i] + b[i])) 152 | b.append(beta * (a[i + 1] - a[i]) + (1 - beta) * b[i]) 153 | s.append(gamma * (Y[i] - 
a[i] - b[i]) + (1 - gamma) * s[i]) 154 | y.append(a[i + 1] + b[i + 1] + s[i + 1]) 155 | 156 | rmse = sqrt(sum([(m - n) ** 2 for m, n in zip(Y[:-fc], y[:-fc - 1])]) / len(Y[:-fc])) 157 | 158 | return Y[-fc:], alpha, beta, gamma, rmse 159 | 160 | def multiplicative(x, m, fc, alpha = None, beta = None, gamma = None): 161 | 162 | Y = x[:] 163 | 164 | if (alpha == None or beta == None or gamma == None): 165 | 166 | initial_values = array([0.0, 1.0, 0.0]) 167 | boundaries = [(0, 1), (0, 1), (0, 1)] 168 | type = 'multiplicative' 169 | 170 | parameters = fmin_l_bfgs_b(RMSE, x0 = initial_values, args = (Y, type, m), bounds = boundaries, approx_grad = True) 171 | alpha, beta, gamma = parameters[0] 172 | 173 | a = [sum(Y[0:m]) / float(m)] 174 | b = [(sum(Y[m:2 * m]) - sum(Y[0:m])) / m ** 2] 175 | s = [Y[i] / a[0] for i in range(m)] 176 | y = [(a[0] + b[0]) * s[0]] 177 | rmse = 0 178 | 179 | for i in range(len(Y) + fc): 180 | 181 | if i == len(Y): 182 | Y.append((a[-1] + b[-1]) * s[-m]) 183 | 184 | a.append(alpha * (Y[i] / s[i]) + (1 - alpha) * (a[i] + b[i])) 185 | b.append(beta * (a[i + 1] - a[i]) + (1 - beta) * b[i]) 186 | s.append(gamma * (Y[i] / (a[i] + b[i])) + (1 - gamma) * s[i]) 187 | y.append((a[i + 1] + b[i + 1]) * s[i + 1]) 188 | 189 | rmse = sqrt(sum([(m - n) ** 2 for m, n in zip(Y[:-fc], y[:-fc - 1])]) / len(Y[:-fc])) 190 | 191 | return Y[-fc:], alpha, beta, gamma, rmse -------------------------------------------------------------------------------- /lore/estimators/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Naive Estimator 4 | **************** 5 | A naive estimator is a useful baseline against which to benchmark more complex models. 6 | A naive estimator will return the mean of the outcome for regression models and 7 | the plurality class for classification models. Note that currently only binary classification 8 | is implemented. For binary classifiers, the majority class will be returned 9 | """ 10 | from __future__ import absolute_import 11 | import inspect 12 | import logging 13 | import lore 14 | from lore.env import require 15 | import lore.estimators 16 | from lore.util import timed, before_after_callbacks 17 | 18 | require(lore.dependencies.NUMPY) 19 | import numpy 20 | 21 | 22 | class Base(lore.estimators.Base): 23 | """Base class for the Naive estimator. Implements functionality common to all Naive models""" 24 | def __init__(self): 25 | super(Base, self).__init__() 26 | 27 | @before_after_callbacks 28 | @timed(logging.INFO) 29 | def fit(self, x, y, **kwargs): 30 | """ 31 | Fit a naive model 32 | :param x: Predictors to use for fitting the data (this will not be used in naive models) 33 | :param y: Outcome 34 | """ 35 | self.mean = numpy.mean(y) 36 | return {} 37 | 38 | @before_after_callbacks 39 | @timed(logging.INFO) 40 | def predict(self, dataframe): 41 | """ 42 | .. 
_naive_base_predict 43 | Predict using the model 44 | :param dataframe: Dataframe against which to make predictions 45 | """ 46 | pass 47 | 48 | @before_after_callbacks 49 | @timed(logging.INFO) 50 | def evaluate(self, x, y): 51 | # TODO 52 | return 0 53 | 54 | @before_after_callbacks 55 | @timed(logging.INFO) 56 | def score(self, x, y): 57 | # TODO 58 | return 0 59 | 60 | 61 | class Naive(Base): 62 | def __init__(self): 63 | frame, filename, line_number, function_name, lines, index = inspect.stack()[1] 64 | super(Naive, self).__init__() 65 | 66 | 67 | class Regression(Base): 68 | @before_after_callbacks 69 | @timed(logging.INFO) 70 | def predict(self, dataframe): 71 | """See :ref:`Base Estimator for Naive _naive_base_predict`""" 72 | return numpy.full(dataframe.shape[0], self.mean) 73 | 74 | 75 | class BinaryClassifier(Base): 76 | @before_after_callbacks 77 | @timed(logging.INFO) 78 | def predict(self, dataframe): 79 | """See :ref:`Base Estimator for Naive _naive_base_predict`""" 80 | if self.mean > 0.5: 81 | return numpy.ones(dataframe.shape[0]) 82 | else: 83 | return numpy.zeros(dataframe.shape[0]) 84 | 85 | @before_after_callbacks 86 | @timed(logging.INFO) 87 | def predict_proba(self, dataframe): 88 | """Predict probabilities using the model 89 | :param dataframe: Dataframe against which to make predictions 90 | """ 91 | ret = numpy.ones((dataframe.shape[0], 2)) 92 | ret[:, 0] = (1 - self.mean) 93 | ret[:, 1] = self.mean 94 | return ret 95 | -------------------------------------------------------------------------------- /lore/estimators/sklearn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | scikit-learn Estimator 4 | **************** 5 | This estimator allows you to use any scikit-learn estimator of your choice. 
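For example, a minimal sketch wrapping a scikit-learn regressor (``x`` and
``y`` here stand in for your own feature matrix and target)::

    from sklearn.ensemble import RandomForestRegressor
    from lore.estimators.sklearn import Regression

    estimator = Regression(RandomForestRegressor(n_estimators=10))
    estimator.fit(x, y)   # returns {'eval_metric': ..., 'train': ...}
    predictions = estimator.predict(x)
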
6 | Note that the underlying estimator can always be accessed as ``Base(estimator).sklearn`` 7 | """ 8 | from __future__ import absolute_import 9 | import inspect 10 | import logging 11 | import warnings 12 | 13 | import lore 14 | import lore.estimators 15 | from lore.env import require 16 | from lore.util import timed, before_after_callbacks 17 | 18 | require(lore.dependencies.SKLEARN) 19 | import sklearn 20 | 21 | 22 | class Base(lore.estimators.Base): 23 | def __init__(self, estimator, eval_metric='sklearn_default', scoring_metric='sklearn_default'): 24 | super(Base, self).__init__() 25 | self.sklearn = estimator 26 | self.eval_metric = eval_metric 27 | self.scoring_metric = scoring_metric 28 | 29 | def __setstate__(self, dict): 30 | self.__dict__ = dict 31 | backward_compatible_defaults = { 32 | 'eval_metric': 'sklearn_default', 33 | 'scoring_metric': 'sklearn_default' 34 | } 35 | for key, default in backward_compatible_defaults.items(): 36 | if key not in self.__dict__.keys(): 37 | self.__dict__[key] = default 38 | 39 | @before_after_callbacks 40 | @timed(logging.INFO) 41 | def fit(self, x, y, validation_x=None, validation_y=None, **sklearn_kwargs): 42 | self.sklearn.fit(x, y=y, **sklearn_kwargs) 43 | results = {'eval_metric': self.eval_metric, 44 | 'train': self.evaluate(x, y)} 45 | if validation_x is not None and validation_y is not None: 46 | results['validate'] = self.evaluate(validation_x, validation_y) 47 | return results 48 | 49 | @before_after_callbacks 50 | @timed(logging.INFO) 51 | def predict(self, dataframe): 52 | return self.sklearn.predict(dataframe) 53 | 54 | @before_after_callbacks 55 | @timed(logging.INFO) 56 | def evaluate(self, x, y): 57 | return self.sklearn.score(x, y) 58 | 59 | @before_after_callbacks 60 | @timed(logging.INFO) 61 | def score(self, x, y): 62 | return self.evaluate(x, y) 63 | 64 | 65 | class SKLearn(Base): 66 | def __init__(self, estimator): 67 | frame, filename, line_number, function_name, lines, index = inspect.stack()[1] 68 | warnings.showwarning('Please import SKLearn with "from lore.estimators.sklearn import Base"', 69 | DeprecationWarning, 70 | filename, line_number) 71 | super(SKLearn, self).__init__(estimator) 72 | 73 | 74 | class Regression(Base): 75 | pass 76 | 77 | 78 | class BinaryClassifier(Base): 79 | def __init__(self, estimator): 80 | super(BinaryClassifier, self).__init__(estimator, eval_metric='logloss', scoring_metric='auc') 81 | 82 | @before_after_callbacks 83 | @timed(logging.INFO) 84 | def evaluate(self, x, y): 85 | y_pred = self.predict_proba(x) 86 | return sklearn.metrics.log_loss(y, y_pred) 87 | 88 | @before_after_callbacks 89 | @timed(logging.INFO) 90 | def score(self, x, y): 91 | y_pred = self.predict_proba(x)[:, 1] 92 | return sklearn.metrics.roc_auc_score(y, y_pred) 93 | 94 | @before_after_callbacks 95 | @timed(logging.INFO) 96 | def predict_proba(self, dataframe): 97 | """Predict probabilities using the model 98 | :param dataframe: Dataframe against which to make predictions 99 | """ 100 | return self.sklearn.predict_proba(dataframe) 101 | 102 | 103 | class MutliClassifier(Base): 104 | pass 105 | -------------------------------------------------------------------------------- /lore/estimators/xgboost.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import inspect 3 | import logging 4 | import warnings 5 | import threading 6 | 7 | import lore.env 8 | import lore.estimators 9 | from lore.util import timed, before_after_callbacks 10 | 11 | 
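# lore.env.require() validates these pinned dependency groups at import time
# (declared in lore/dependencies.py rather than setup.py, per CONTRIBUTING.md)
# so that the `import xgboost` below can succeed.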
lore.env.require( 12 | lore.dependencies.XGBOOST + 13 | lore.dependencies.SKLEARN 14 | ) 15 | 16 | import xgboost 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Base(object): 23 | def __init__(self, **xgboost_params): 24 | self.eval_metric = xgboost_params.pop('eval_metric', None) 25 | self.scoring_metric = xgboost_params.pop('scoring_metric', None) 26 | self.xgboost_lock = threading.RLock() 27 | self.missing = None 28 | super(Base, self).__init__(**xgboost_params) 29 | 30 | def __getstate__(self): 31 | state = super(Base, self).__getstate__() 32 | state['xgboost_lock'] = None 33 | return state 34 | 35 | def __setstate__(self, state): 36 | self.__dict__ = state 37 | self.xgboost_lock = threading.RLock() 38 | 39 | backward_compatible_defaults = { 40 | 'n_jobs': state.pop('nthread', -1), 41 | 'random_state': state.pop('seed', 0) 42 | } 43 | for key, default in backward_compatible_defaults.items(): 44 | if key not in self.__dict__.keys(): 45 | self.__dict__[key] = default 46 | 47 | @before_after_callbacks 48 | @timed(logging.INFO) 49 | def fit(self, x, y, validation_x=None, validation_y=None, patience=0, verbose=None, **xgboost_kwargs): 50 | eval_set = [(x, y)] 51 | if validation_x is not None and validation_y is not None: 52 | eval_set += [(validation_x, validation_y)] 53 | if verbose is None: 54 | verbose = True if lore.env.NAME == lore.env.DEVELOPMENT else False 55 | try: 56 | super(Base, self).fit( 57 | X=x, 58 | y=y, 59 | eval_set=eval_set, 60 | eval_metric=self.eval_metric, 61 | verbose=verbose, 62 | early_stopping_rounds=patience, 63 | **xgboost_kwargs 64 | ) 65 | except KeyboardInterrupt: 66 | logger.warning('Caught SIGINT. Training aborted.') 67 | 68 | evals = super(Base, self).evals_result() 69 | 70 | if self.scoring_metric is None: 71 | self.scoring_metric = self.eval_metric 72 | 73 | results = { 74 | 'eval_metric': self.eval_metric, 75 | 'train': evals['validation_0'][self.eval_metric][self.best_iteration], 76 | 'best_iteration': self.best_iteration 77 | } 78 | if validation_x is not None: 79 | results['validate'] = evals['validation_1'][self.eval_metric][self.best_iteration] 80 | return results 81 | 82 | @before_after_callbacks 83 | @timed(logging.INFO) 84 | def predict(self, dataframe, ntree_limit=None): 85 | if ntree_limit is None: 86 | ntree_limit = self.best_ntree_limit or 0 87 | with self.xgboost_lock: 88 | return super(Base, self).predict(dataframe, ntree_limit=ntree_limit) 89 | 90 | @before_after_callbacks 91 | @timed(logging.INFO) 92 | def predict_proba(self, dataframe, ntree_limit=None): 93 | if ntree_limit is None: 94 | ntree_limit = self.best_ntree_limit or 0 95 | with self.xgboost_lock: 96 | return super(Base, self).predict_proba(dataframe, ntree_limit=ntree_limit) 97 | 98 | @before_after_callbacks 99 | @timed(logging.INFO) 100 | def evaluate(self, x, y): 101 | with self.xgboost_lock: 102 | return float(self.get_booster().eval(xgboost.DMatrix(x, label=y)).split(':')[-1]) 103 | 104 | @before_after_callbacks 105 | @timed(logging.INFO) 106 | def score(self, x, y): 107 | return self.evaluate(x, y) 108 | 109 | 110 | class XGBoost(lore.estimators.Base): 111 | def __init__(self, **kwargs): 112 | frame, filename, line_number, function_name, lines, index = inspect.stack()[1] 113 | warnings.showwarning('Please import XGBoost with "from lore.estimators.xgboost import Base"', 114 | DeprecationWarning, 115 | filename, line_number) 116 | super(XGBoost, self).__init__(**kwargs) 117 | 118 | 119 | class Regression(Base, xgboost.XGBRegressor): 120 | def 
__init__( 121 | self, 122 | max_depth=3, 123 | learning_rate=0.1, 124 | n_estimators=100, 125 | silent=True, 126 | objective='reg:linear', 127 | booster='gbtree', 128 | n_jobs=-1, 129 | gamma=0, 130 | min_child_weight=1, 131 | max_delta_step=0, 132 | subsample=1, 133 | colsample_bytree=1, 134 | colsample_bylevel=1, 135 | reg_alpha=0, 136 | reg_lambda=1, 137 | scale_pos_weight=1, 138 | base_score=0.5, 139 | random_state=0, 140 | missing=None, 141 | eval_metric='rmse', 142 | **kwargs 143 | ): 144 | kwargs = locals() 145 | kwargs.pop('self') 146 | kwargs.pop('__class__', None) 147 | kwargs = dict(kwargs, **(kwargs.pop('kwargs', {}))) 148 | if 'random_state' not in kwargs and 'seed' in kwargs: 149 | kwargs['random_state'] = kwargs.pop('seed') 150 | if 'n_jobs' not in kwargs and 'nthread' in kwargs: 151 | kwargs['n_jobs'] = kwargs.pop('nthread') 152 | super(Regression, self).__init__(**kwargs) 153 | 154 | 155 | class BinaryClassifier(Base, xgboost.XGBClassifier): 156 | def __init__( 157 | self, 158 | max_depth=3, 159 | learning_rate=0.1, 160 | n_estimators=100, 161 | silent=True, 162 | objective='binary:logistic', 163 | booster='gbtree', 164 | n_jobs=-1, 165 | gamma=0, 166 | min_child_weight=1, 167 | max_delta_step=0, 168 | subsample=1, 169 | colsample_bytree=1, 170 | colsample_bylevel=1, 171 | reg_alpha=0, 172 | reg_lambda=1, 173 | scale_pos_weight=1, 174 | base_score=0.5, 175 | random_state=0, 176 | missing=None, 177 | eval_metric='logloss', 178 | scoring_metric='auc', 179 | **kwargs 180 | ): 181 | kwargs = locals() 182 | kwargs.pop('self') 183 | kwargs.pop('__class__', None) 184 | kwargs = dict(kwargs, **(kwargs.pop('kwargs', {}))) 185 | if 'random_state' not in kwargs and 'seed' in kwargs: 186 | kwargs['random_state'] = kwargs.pop('seed') 187 | if 'n_jobs' not in kwargs and 'nthread' in kwargs: 188 | kwargs['n_jobs'] = kwargs.pop('nthread') 189 | super(BinaryClassifier, self).__init__(**kwargs) 190 | 191 | @before_after_callbacks 192 | @timed(logging.INFO) 193 | def score(self, x, y): 194 | import sklearn.metrics  # plain `import sklearn` does not load the metrics submodule 195 | y_pred = self.predict_proba(x)[:, 1] 196 | return sklearn.metrics.roc_auc_score(y, y_pred) 197 | 198 | 199 | MultiClassifier = BinaryClassifier 200 | # Preserve the original misspelling as a backward-compatible alias. 201 | MutliClassifier = MultiClassifier 202 | -------------------------------------------------------------------------------- /lore/features/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/lore/features/__init__.py -------------------------------------------------------------------------------- /lore/features/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import datetime 3 | from abc import ABCMeta, abstractmethod 4 | 5 | import lore 6 | from lore.env import require 7 | from lore.util import convert_df_columns_to_json 8 | 9 | require( 10 | lore.dependencies.PANDAS + 11 | lore.dependencies.INFLECTION 12 | ) 13 | 14 | import pandas 15 | import inflection 16 | 17 | 18 | class BaseFeatureExporter(object): 19 | __metaclass__ = ABCMeta 20 | 21 | def __init__(self, collection_ts=None):  # a datetime.now() default argument is evaluated once, at import time 22 | self.collection_ts = collection_ts if collection_ts is not None else datetime.datetime.now() 23 | 24 | @property 25 | def key(self): 26 | """ 27 | :return: Composite or a single key for index 28 | """ 29 | raise NotImplementedError 30 | 31 | @property 32 | def timestamp(self): 33 | return datetime.datetime.combine(self.collection_ts.date(), 34 | datetime.datetime.min.time()) 35 | 36 | @abstractmethod 37 | def get_data(self): 38 | 
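# (Editor's note, illustrative only) Subclasses implement get_data() to
# return a pandas DataFrame containing every column named by `key` plus
# exactly one value column -- `_values` below raises ValueError otherwise.
# A hypothetical exporter, assuming a `main` connection configured in
# config/database.cfg:
#
#     class UserWarehouseSearches(BaseFeatureExporter):
#         @property
#         def key(self):
#             return ['user_id', 'warehouse_id']
#
#         def get_data(self):
#             return lore.io.main.dataframe('SELECT user_id, warehouse_id, searches FROM searches')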
pass 39 | 40 | @abstractmethod 41 | def publish(self): 42 | """ 43 | Publish the feature to a store (S3, Memcache, Redis, Cassandra, etc.) 44 | :return: None 45 | """ 46 | pass 47 | 48 | @property 49 | def _raw_data(self): 50 | return self.get_data() 51 | 52 | @property 53 | def version(self):  # properties take no arguments; the old `version=...` parameter was dead code 54 | """ 55 | Feature version : Override this property if you want to manage versions yourself 56 | ex 'v1', 'v2' 57 | By default a static version string is returned 58 | 59 | :return: the version string 60 | """ 61 | return 'v1' 62 | 63 | @property 64 | def name(self): 65 | return inflection.underscore(self._value) 66 | 67 | @property 68 | def _values(self): 69 | value_cols = set(self._raw_data.columns.values.tolist()) - set(self.key) 70 | if len(value_cols) > 1: 71 | raise ValueError('Only one feature column allowed') 72 | return list(value_cols) 73 | 74 | @property 75 | def _value(self): 76 | return self._values[0] 77 | 78 | def _features_as_kv(self): 79 | """ 80 | Return feature rows as kv pairs so that they can be stored in memcache or redis and 81 | used at the serving layer 82 | :return: a nested hash for each column 83 | """ 84 | self._data = self.get_data() 85 | key_list = self.key  # `key`, `_values` and `cache_key_prefix` are properties, not methods 86 | values_list = self._values 87 | result = {} 88 | for column in values_list: 89 | key_prefix = self.cache_key_prefix + "#" + column 90 | self._data['cache_key'] = self._data[key_list].apply(lambda xdf: key_prefix + "=" + '#'.join(xdf.astype(str).values), axis=1) 91 | result[column] = dict(zip(self._data.cache_key.values, self._data[column].values)) 92 | return result 93 | 94 | @property 95 | def cache_key_prefix(self): 96 | return '#'.join(self.key) 97 | 98 | def _generate_row_keys(self, df): 99 | """ 100 | Method to generate row keys for storage in the DB 101 | :param df: DataFrame to generate row keys for 102 | 103 | This method will use the key definition initially provided 104 | and convert those columns into a JSON column 105 | :return: 106 | """ 107 | keys = self.key 108 | return convert_df_columns_to_json(df, keys) 109 | 110 | def _generate_row_keys_for_serving(self, df): 111 | """ 112 | Method for generating key features at serving time or prediction time 113 | :param df: Pass in the data that is necessary for generating the keys 114 | Example : 115 | Feature : User warehouse searches and conversions 116 | Keys will be of the form 'user_id#warehouse_id#searches=23811676#3' 117 | Keys will be of the form 'user_id#warehouse_id#conversions=23811676#3' 118 | The DataFrame should have values for all the key columns, in this case ['user_id', 'warehouse_id'] 119 | :return: 120 | """ 121 | keys = self.key 122 | key_prefix = self.cache_key_prefix 123 | cache_keys = df[keys].apply(lambda xdf: key_prefix + "=" + '#'.join(xdf.astype(str).values), 124 | axis=1) 125 | return list(cache_keys) 126 | 127 | def __repr__(self): 128 | return ( 129 | """ 130 | Version : {} 131 | Name : {} 132 | Keys : {} 133 | Rows : {} 134 | """.format(self.version, self.name, self.key, len(self._data)) 135 | ) 136 | 137 | def metadata(self): 138 | return { 139 | "version": self.version, 140 | "name": self.name, 141 | "keys": self.key, 142 | "num_rows": len(self._data) 143 | } 144 | 145 | def distribute(self, cache): 146 | """ 147 | Sync features to a key value compliant cache. 
Should adhere to cache protocol 148 | :param cache: 149 | :return: None 150 | """ 151 | data = self._features_as_kv() 152 | for key in data.keys(): 153 | cache.batch_set(data[key]) 154 | 155 | 156 | class BaseFeatureImporter(object): 157 | def __init__(self, entity_name, feature_name, version, start_date, end_date): 158 | self.entity_name = entity_name 159 | self.feature_name = feature_name 160 | self.version = version 161 | self.start_date = start_date 162 | self.end_date = end_date 163 | 164 | @property 165 | def feature_data(self): 166 | raise NotImplementedError 167 | 168 | 169 | -------------------------------------------------------------------------------- /lore/features/db.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from abc import ABCMeta 4 | from datetime import datetime 5 | import lore.io 6 | from lore.features.base import BaseFeatureExporter, BaseFeatureImporter 7 | import lore.metadata 8 | from lore.metadata import FeatureMetaData 9 | from sqlalchemy.orm import sessionmaker, scoped_session 10 | import pandas 11 | from lore.util import convert_df_columns_to_json 12 | import json 13 | 14 | engine = lore.io.metadata._engine 15 | Session = scoped_session(sessionmaker(bind=engine)) 16 | 17 | 18 | class DBFeatureExporter(BaseFeatureExporter): 19 | __metaclass__ = ABCMeta 20 | 21 | @property 22 | def entity_name(self): 23 | raise NotImplementedError 24 | 25 | @property 26 | def _data(self): 27 | df = self._raw_data 28 | df['key'] = self._generate_row_keys(df) 29 | df['created_at'] = datetime.utcnow() 30 | df['feature_data'] = df[self._values] 31 | df.drop(self._values, inplace=True, axis=1) 32 | df.drop(self.key, inplace=True, axis=1) 33 | return df 34 | 35 | @property 36 | def dtypes(self): 37 | df = self._raw_data 38 | dtypes = (df[self.key + self._values] 39 | .dtypes 40 | .to_frame('dtype')) 41 | dtypes = dtypes['dtype'].astype(str).to_dict() 42 | return dtypes 43 | 44 | def publish(self): 45 | df = self._data 46 | feature_metadata = lore.metadata.FeatureMetaData.create(created_at=datetime.utcnow(), 47 | entity_name=self.entity_name, 48 | feature_name=self.name, 49 | version=self.version, 50 | snapshot_at=self.timestamp, 51 | feature_dtypes=self.dtypes, 52 | s3_url=None) 53 | df['feature_metadata_id'] = feature_metadata.id 54 | lore.io.metadata.insert('features', df) 55 | 56 | 57 | class DBFeatureImporter(BaseFeatureImporter): 58 | @property 59 | def feature_data(self): 60 | session = Session() 61 | metadata = (session.query(FeatureMetaData) 62 | .filter_by(entity_name=self.entity_name, 63 | feature_name=self.feature_name, 64 | version=self.version, 65 | s3_url=None) 66 | .filter(FeatureMetaData.snapshot_at.between(self.start_date, self.end_date))) 67 | if metadata.count() == 0: 68 | return pandas.DataFrame() 69 | 70 | metadata_ids = [str(m.id) for m in metadata] 71 | feature_name = metadata[0].feature_name 72 | dtypes = metadata[0].feature_dtypes 73 | sql = """SELECT key, feature_data FROM features where feature_metadata_id in ({feature_metadata_ids})""" 74 | sql = sql.format(feature_metadata_ids=','.join(metadata_ids)) 75 | df = lore.io.metadata.dataframe(sql, feature_metadata_ids=metadata_ids) 76 | if lore.io.metadata.adapter == 'sqlite': 77 | key_df = pandas.io.json.json_normalize(df.key.apply(json.loads)) 78 | else: 79 | key_df = pandas.io.json.json_normalize(df.key) 80 | df.drop('key', axis=1, inplace=True) 81 | df = pandas.concat([df, key_df], axis=1) 82 | df = df.rename({'feature_data': 
feature_name}, axis=1) 83 | df = df.astype(dtypes) 84 | return df 85 | -------------------------------------------------------------------------------- /lore/features/s3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from abc import ABCMeta, abstractmethod 4 | import json 5 | import os 6 | import tempfile 7 | 8 | import lore 9 | from lore.features.base import BaseFeatureExporter 10 | from lore.io import upload 11 | 12 | 13 | class S3FeatureExporter(BaseFeatureExporter): 14 | __metaclass__ = ABCMeta 15 | 16 | @abstractmethod 17 | def serialization(self): 18 | pass 19 | 20 | @property 21 | def _data(self): 22 | df = self.get_data() 23 | return df 24 | 25 | def publish(self, compression='gzip'): 26 | temp_file, temp_path = tempfile.mkstemp(dir=lore.env.DATA_DIR) 27 | data = self._data 28 | 29 | if self.serialization() == 'csv': 30 | data.to_csv(temp_path, index=False, compression=compression) 31 | elif self.serialization() == 'pickle': 32 | data.to_pickle(temp_path, compression=compression) 33 | else: 34 | raise ValueError("Invalid serialization: %s" % self.serialization())  # raising a bare string is a TypeError in Python 3 35 | upload(temp_path, self.data_path()) 36 | 37 | with open(temp_path, 'w') as f: 38 | f.write(json.dumps(self.metadata())) 39 | upload(temp_path, self.metadata_path()) 40 | os.close(temp_file) 41 | os.remove(temp_path) 42 | 43 | def data_path(self): 44 | return "{}/{}/data.{}".format(self.version, self.name, self.serialization()) 45 | 46 | def metadata_path(self): 47 | return "{}/{}/metadata.json".format(self.version, self.name) 48 | -------------------------------------------------------------------------------- /lore/io/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import re 5 | import shutil 6 | import tarfile 7 | import tempfile 8 | try: 9 | from urllib.parse import urlparse 10 | except ImportError: 11 | from urlparse import urlparse 12 | 13 | import lore 14 | from lore.env import require, configparser 15 | from lore.util import timer 16 | from lore.io.connection import Connection 17 | from lore.io.multi_connection_proxy import MultiConnectionProxy 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | config = lore.env.DATABASE_CONFIG 24 | if config: 25 | try: 26 | for database, url in config.items('DATABASES'): 27 | vars()[database] = Connection(url=url, name=database) 28 | except configparser.NoSectionError: 29 | pass 30 | 31 | for section in config.sections(): 32 | if section == 'DATABASES': 33 | continue 34 | 35 | options = config._sections[section] 36 | if options.get('url') == '$DATABASE_URL': 37 | logger.error('$DATABASE_URL is not set, but is used in config/database.cfg. 
Skipping connection.') 38 | else: 39 | if 'urls' in options: 40 | vars()[section.lower()] = MultiConnectionProxy(name=section.lower(), **options) 41 | else: 42 | vars()[section.lower()] = Connection(name=section.lower(), **options) 43 | 44 | 45 | if 'metadata' not in vars(): 46 | vars()['metadata'] = Connection('sqlite:///%s/metadata.sqlite' % lore.env.DATA_DIR) 47 | 48 | redis_config = lore.env.REDIS_CONFIG 49 | if redis_config: 50 | require(lore.dependencies.REDIS) 51 | import redis 52 | 53 | for section in redis_config.sections(): 54 | vars()[section.lower()] = redis.StrictRedis(host=redis_config.get(section, 'url'), 55 | port=redis_config.get(section, 'port')) 56 | 57 | s3 = None 58 | bucket = None 59 | if lore.env.AWS_CONFIG: 60 | require(lore.dependencies.S3) 61 | import boto3 62 | from botocore.exceptions import ClientError 63 | 64 | config = lore.env.AWS_CONFIG 65 | if config and 'ACCESS_KEY' in config.sections(): 66 | s3 = boto3.resource( 67 | 's3', 68 | aws_access_key_id=config.get('ACCESS_KEY', 'id'), 69 | aws_secret_access_key=config.get('ACCESS_KEY', 'secret') 70 | ) 71 | else: 72 | s3 = boto3.resource('s3') 73 | 74 | if s3 and config and 'BUCKET' in config.sections(): 75 | bucket = s3.Bucket(config.get('BUCKET', 'name')) 76 | 77 | 78 | def download(remote_url, local_path=None, cache=True, extract=False): 79 | _bucket = bucket 80 | if re.match(r'^https?://', remote_url): 81 | protocol = 'http' 82 | elif re.match(r'^s3://', remote_url):  # the old pattern r'^s3?://' also matched 's://' 83 | require(lore.dependencies.S3) 84 | import boto3 85 | from botocore.exceptions import ClientError 86 | protocol = 's3' 87 | url_parts = urlparse(remote_url) 88 | remote_url = url_parts.path[1:] 89 | _bucket = boto3.resource('s3').Bucket(url_parts.netloc) 90 | else: 91 | if s3 is None or bucket is None: 92 | raise NotImplementedError("Cannot download from s3 without config/aws.cfg") 93 | protocol = 's3' 94 | remote_url = prefix_remote_root(remote_url) 95 | if cache: 96 | if local_path is None: 97 | if protocol == 'http': 98 | filename = lore.env.parse_url(remote_url).path.split('/')[-1] 99 | elif protocol == 's3': 100 | filename = remote_url 101 | local_path = os.path.join(lore.env.DATA_DIR, filename) 102 | 103 | if os.path.exists(local_path): 104 | return local_path 105 | elif local_path: 106 | raise ValueError("You can't pass lore.io.download(local_path=X), unless you also pass cache=True") 107 | elif extract: 108 | raise ValueError("You can't pass lore.io.download(extract=True), unless you also pass cache=True") 109 | 110 | with timer('DOWNLOAD: %s' % remote_url): 111 | temp_file, temp_path = tempfile.mkstemp(dir=lore.env.WORK_DIR) 112 | os.close(temp_file) 113 | try: 114 | if protocol == 'http': 115 | lore.env.retrieve_url(remote_url, temp_path) 116 | else: 117 | _bucket.download_file(remote_url, temp_path) 118 | except ClientError as e: 119 | logger.error("Error downloading file: %s" % e) 120 | raise 121 | 122 | if cache: 123 | dir = os.path.dirname(local_path) 124 | if not os.path.exists(dir): 125 | try: 126 | os.makedirs(dir) 127 | except OSError:  # FileExistsError subclasses OSError on py3; `os.FileExistsError` does not exist 128 | pass # race to create 129 | 130 | shutil.copy(temp_path, local_path) 131 | os.remove(temp_path) 132 | 133 | if extract: 134 | with timer('EXTRACT: %s' % local_path, logging.WARNING): 135 | if local_path[-7:] == '.tar.gz': 136 | with tarfile.open(local_path, 'r:gz') as tar: 137 | tar.extractall(os.path.dirname(local_path)) 138 | elif local_path[-4:] == '.zip': 139 | import zipfile 140 | with zipfile.ZipFile(local_path, 'r') as zip: 141 | 
zip.extractall(os.path.dirname(local_path)) 142 | 143 | else: 144 | local_path = temp_path 145 | return local_path 146 | 147 | 148 | # Note: This can be rewritten in a more efficient way 149 | # https://stackoverflow.com/questions/11426560/amazon-s3-boto-how-to-delete-folder 150 | def delete_folder(remote_url): 151 | if remote_url is None: 152 | raise ValueError("remote_url cannot be None") 153 | else: 154 | remote_url = prefix_remote_root(remote_url) 155 | if not remote_url.endswith('/'): 156 | remote_url = remote_url + '/' 157 | keys = bucket.objects.filter(Prefix=remote_url) 158 | empty = True 159 | 160 | for key in keys: 161 | empty = False 162 | key.delete() 163 | 164 | if empty: 165 | logger.info('Remote was not a folder') 166 | 167 | 168 | def delete(remote_url, recursive=False): 169 | if s3 is None: 170 | raise NotImplementedError("Cannot delete from s3 without config/aws.cfg") 171 | 172 | if remote_url is None: 173 | raise ValueError("remote_url cannot be None") 174 | 175 | if (recursive is False) and (remote_url.endswith('/')): 176 | raise ValueError("remote_url cannot end with trailing / when recursive is False") 177 | 178 | remote_url = prefix_remote_root(remote_url) 179 | if recursive is True: 180 | delete_folder(remote_url) 181 | else: 182 | obj = bucket.Object(key=remote_url) 183 | obj.delete() 184 | 185 | 186 | def upload_object(obj, remote_path=None): 187 | if remote_path is None: 188 | raise ValueError("remote_path cannot be None when uploading objects") 189 | else: 190 | with tempfile.NamedTemporaryFile(delete=False) as f: 191 | pickle.dump(obj, f) 192 | upload_file(f.name, remote_path) 193 | os.remove(f.name) 194 | 195 | 196 | def upload_file(local_path, remote_path=None): 197 | if s3 is None: 198 | raise NotImplementedError("Cannot upload to s3 without config/aws.cfg") 199 | 200 | if remote_path is None: 201 | remote_path = remote_from_local(local_path) 202 | remote_path = prefix_remote_root(remote_path) 203 | 204 | with timer('UPLOAD: %s -> %s' % (local_path, remote_path)): 205 | try: 206 | bucket.upload_file(local_path, remote_path, ExtraArgs={'ServerSideEncryption': 'AES256'}) 207 | except ClientError as e: 208 | logger.error("Error uploading file: %s" % e) 209 | raise 210 | 211 | 212 | def upload(obj, remote_path=None): 213 | if isinstance(obj, str): 214 | local_path = obj 215 | upload_file(local_path, remote_path) 216 | else: 217 | upload_object(obj, remote_path) 218 | 219 | 220 | def remote_from_local(local_path): 221 | return re.sub( 222 | r'^%s' % re.escape(lore.env.WORK_DIR), 223 | '', 224 | local_path 225 | ) 226 | 227 | 228 | def prefix_remote_root(path): 229 | if path.startswith('/'): 230 | path = path[1:] 231 | 232 | if not path.startswith(lore.env.NAME + '/'): 233 | path = os.path.join(lore.env.NAME, path) 234 | 235 | return path 236 | -------------------------------------------------------------------------------- /lore/io/multi_connection_proxy.py: -------------------------------------------------------------------------------- 1 | import random 2 | import logging 3 | import re 4 | 5 | from lore.io.connection import Connection 6 | from lore.util import scrub_url 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class MultiConnectionProxy(object): 12 | 13 | SQL_RUNNING_METHODS = ['dataframe', 'unload', 'select', 'execute', 'temp_table'] 14 | 15 | def __init__(self, urls, name='connection', watermark=True, **kwargs): 16 | sticky = kwargs.pop('sticky_connection', None) 17 | sticky = False if sticky is None else (sticky.lower() == 'true') 18 | 19 | 
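# (Editor's note, illustrative only) `sticky_connection` arrives from the
# config file as a string: "true" pins every query to the first shuffled
# connection, while the default re-shuffles before each method named in
# SQL_RUNNING_METHODS (see __getattr__ below). A hypothetical
# config/database.cfg section that exercises this proxy:
#
#     [MAIN]
#     urls: postgres://replica-1/app postgres://replica-2/app
#     sticky_connection: true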
self._urls = urls 20 | self._sticky = sticky 21 | self._connections = [] 22 | self._active_connection = None 23 | 24 | self.parse_connections(name, watermark, **kwargs) 25 | 26 | def parse_connections(self, name, watermark, **kwargs): 27 | kwargs.pop('url', None) 28 | for url in re.split(r'\s+', self._urls): 29 | c = Connection(url, name=name, watermark=watermark, **kwargs) 30 | self._connections.append(c) 31 | self.shuffle_connections() 32 | 33 | def shuffle_connections(self): 34 | if len(self._connections) == 0: 35 | return 36 | if len(self._connections) == 1: 37 | self._active_connection = self._connections[0] 38 | else: 39 | filtered = list(filter(lambda x: x is not self._active_connection, self._connections)) 40 | self._active_connection = filtered[0] if len(filtered) == 1 else random.choice(filtered) 41 | self.log_connection() 42 | 43 | def log_connection(self): 44 | logger.debug("using database connection {}".format(scrub_url(self._active_connection.url))) 45 | 46 | # proxying - forward getattr to self._active_connection if not defined in MultiConnectionProxy 47 | 48 | def __getattr__(self, attr): 49 | if not self._sticky and attr in self.SQL_RUNNING_METHODS: 50 | self.shuffle_connections() 51 | return getattr(self._active_connection, attr) 52 | -------------------------------------------------------------------------------- /lore/metadata/__init__.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import inflection 3 | import subprocess 4 | import logging 5 | import os 6 | 7 | from sqlalchemy import Column, Float, Integer, String, DateTime, \ 8 | JSON, func, ForeignKey, Index, UniqueConstraint 9 | from sqlalchemy.dialects.postgresql import JSONB 10 | from sqlalchemy.ext.declarative import declarative_base, declared_attr 11 | from sqlalchemy.orm import sessionmaker, relationship, scoped_session 12 | from sqlalchemy import TypeDecorator, types, desc 13 | from sqlalchemy.inspection import inspect 14 | import lore.io 15 | import json 16 | 17 | logger = logging.getLogger(__name__) 18 | Base = declarative_base() 19 | adapter = lore.io.metadata.adapter 20 | engine = lore.io.metadata._engine 21 | Session = scoped_session(sessionmaker(bind=engine)) 22 | 23 | 24 | if adapter == 'sqlite': 25 | # JSON support is not available in SQLite 26 | class StringJSON(TypeDecorator): 27 | @property 28 | def python_type(self): 29 | return object 30 | 31 | impl = types.String 32 | 33 | def process_bind_param(self, value, dialect): 34 | return json.dumps(value) 35 | 36 | def process_literal_param(self, value, dialect): 37 | return value 38 | 39 | def process_result_value(self, value, dialect): 40 | try: 41 | return json.loads(value) 42 | except (ValueError, TypeError): 43 | return None 44 | JSON = StringJSON 45 | JSONB = StringJSON 46 | 47 | # Commenting sqlite queries with the SQLAlchemy declarative_base API 48 | # is broken: https://github.com/sqlalchemy/sqlalchemy/issues/4396 49 | engine.dialect.supports_sane_rowcount = False 50 | engine.dialect.supports_sane_multi_rowcount = False # for executemany() 51 | 52 | 53 | class Crud(object): 54 | query = Session.query_property() 55 | 56 | @declared_attr 57 | def __tablename__(cls): 58 | return inflection.pluralize(inflection.underscore(cls.__name__)) 59 | 60 | def __repr__(self): 61 | properties = ['%s=%s' % (key, value) for key, value in self.__dict__.items() if key[0] != '_'] 62 | return '<%s(%s)>' % (self.__class__.__name__, ', '.join(properties)) 63 | 64 | @classmethod 65 | def create(cls, 
**kwargs): 66 | self = cls(**kwargs) 67 | self.save() 68 | return self 69 | 70 | @classmethod 71 | def get(cls, *key): 72 | session = Session() 73 | 74 | filter = {str(k.name): v for k, v in dict(zip(inspect(cls).primary_key, key)).items()} 75 | instance = session.query(cls).filter_by(**filter).first() 76 | session.close() 77 | return instance 78 | 79 | @classmethod 80 | def get_or_create(cls, **kwargs): 81 | ''' 82 | Creates an object or returns the object if exists 83 | credit to Kevin @ StackOverflow 84 | from: http://stackoverflow.com/questions/2546207/does-sqlalchemy-have-an-equivalent-of-djangos-get-or-create 85 | ''' 86 | session = Session() 87 | instance = session.query(cls).filter_by(**kwargs).first() 88 | session.close() 89 | 90 | if not instance: 91 | self = cls(**kwargs) 92 | self.save() 93 | else: 94 | self = instance 95 | 96 | return self 97 | 98 | @classmethod 99 | def all(cls, order_by=None, limit=None, **filters): 100 | session = Session() 101 | query = session.query(cls) 102 | if filters: 103 | query = query.filter_by(**filters) 104 | if isinstance(order_by, list) or isinstance(order_by, tuple): 105 | query = query.order_by(*order_by) 106 | elif order_by is not None: 107 | query = query.order_by(order_by) 108 | if limit: 109 | query = query.limit(limit) 110 | result = query.all() 111 | session.close() 112 | return result 113 | 114 | @classmethod 115 | def last(cls, order_by=None, limit=1, **filters): 116 | if order_by is None: 117 | order_by = inspect(cls).primary_key 118 | if isinstance(order_by, list) or isinstance(order_by, tuple): 119 | order_by = desc(*order_by) 120 | else: 121 | order_by = desc(order_by) 122 | return cls.first(order_by=order_by, limit=limit, **filters) 123 | 124 | @classmethod 125 | def first(cls, order_by=None, limit=1, **filters): 126 | if order_by is None: 127 | order_by = inspect(cls).primary_key 128 | result = cls.all(order_by=order_by, limit=limit, **filters) 129 | 130 | if limit == 1: 131 | if len(result) == 0: 132 | result = None 133 | else: 134 | result = result[0] 135 | 136 | return result 137 | 138 | def save(self): 139 | session = Session() 140 | session.add(self) 141 | try: 142 | return session.commit() 143 | except Exception as ex: 144 | session.rollback() 145 | raise 146 | 147 | def update(self, **kwargs): 148 | for key, value in kwargs.items(): 149 | self.__dict__[key] = value 150 | return self.save() 151 | 152 | def delete(self): 153 | session = Session() 154 | session.delete(self) 155 | try: 156 | return session.commit() 157 | except Exception as ex: 158 | session.rollback() 159 | raise 160 | 161 | 162 | class Commit(Crud, Base): 163 | sha = Column(String, primary_key=True) 164 | created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) 165 | updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now) 166 | message = Column(String) 167 | author_name = Column(String, index=True) 168 | author_email = Column(String) 169 | fittings = relationship('Fitting', back_populates='commit') 170 | snapshots = relationship('Snapshot', back_populates='commit') 171 | 172 | @classmethod 173 | def from_git(cls, sha='HEAD'): 174 | process = subprocess.Popen([ 175 | 'git', 176 | 'rev-list', 177 | '--format=NAME: %an%nEMAIL: %aE%nDATE: %at%nMESSAGE:%n%B', 178 | '--max-count=1', 179 | sha, 180 | ], stdout=subprocess.PIPE) 181 | out, err = process.communicate() 182 | 183 | # If there is no Git repo, exit code will be non-zero 184 | if process.returncode == 0: 185 | lines = out.strip().decode().split(os.linesep) 186 | 
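# (Editor's note, illustrative only) The parsing below depends on the exact
# layout produced by the --format string passed to `git rev-list` above:
#
#     commit <sha>
#     NAME: Jane Doe
#     EMAIL: jane@example.com
#     DATE: 1514764800
#     MESSAGE:
#     <commit message body>
#
# str.split returns ['', value] when the line starts with the separator, so
# a truthy `check` or an empty value flags a parse failure for the logger.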
187 | check, sha = lines[0].split('commit ') 188 | if check or not sha: 189 | logger.error('bad git parse: %s' % out) 190 | 191 | check, author_name = lines[1].split('NAME: ') 192 | if check or not author_name: 193 | logger.error('bad git parse for NAME: %s' % out) 194 | 195 | check, author_email = lines[2].split('EMAIL: ') 196 | if check or not author_email: 197 | logger.error('bad git parse for EMAIL: %s' % out) 198 | 199 | check, date = lines[3].split('DATE: ') 200 | if check or not date: 201 | logger.error('bad git parse for DATE: %s' % out) 202 | created_at = datetime.datetime.fromtimestamp(int(date)) 203 | 204 | check, message = lines[4], os.linesep.join(lines[5:]) 205 | if check != 'MESSAGE:' or not message: 206 | logger.error('bad git parse for MESSAGE: %s' % out) 207 | 208 | return Commit.get_or_create( 209 | sha=sha, 210 | author_name=author_name, 211 | author_email=author_email, 212 | created_at=created_at, 213 | message=message 214 | ) 215 | else: 216 | return None 217 | 218 | 219 | class Snapshot(Crud, Base): 220 | """ 221 | Metadata summary description of each column in the snapshot 222 | 223 | """ 224 | id = Column(Integer, primary_key=True) 225 | created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) 226 | completed_at = Column(DateTime) 227 | pipeline = Column(String, index=True) 228 | cache = Column(String) 229 | args = Column(String) 230 | commit_sha = Column(String, ForeignKey('commits.sha'), index=True) 231 | # samples = Column(Integer) 232 | bytes = Column(Integer) 233 | head = Column(String) 234 | tail = Column(String) 235 | stats = Column(String) 236 | encoders = Column(JSON) 237 | 238 | description = Column(String) 239 | fittings = relationship('Fitting', back_populates='snapshot') 240 | commit = relationship('Commit', back_populates='snapshots') 241 | 242 | 243 | class Fitting(Crud, Base): 244 | id = Column(Integer, primary_key=True) 245 | commit_sha = Column(String, ForeignKey('commits.sha')) 246 | created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) 247 | completed_at = Column(DateTime) 248 | snapshot_id = Column(Integer, ForeignKey('snapshots.id'), nullable=False, index=True) 249 | scoring_metric = Column(String) 250 | score = Column(Float) 251 | model = Column(String, index=True) 252 | args = Column(JSON) 253 | stats = Column(JSON) 254 | custom_data = Column(JSON) 255 | url = Column(String) 256 | uploaded_at = Column(DateTime) 257 | 258 | commit = relationship('Commit', back_populates='fittings') 259 | predictions = relationship('Prediction', back_populates='fitting') 260 | snapshot = relationship('Snapshot', back_populates='fittings') 261 | 262 | def __init__(self, **kwargs): 263 | if 'commit' not in kwargs: 264 | self.commit = Commit.from_git() 265 | super(Fitting, self).__init__(**kwargs) 266 | 267 | 268 | class Prediction(Crud, Base): 269 | id = Column(Integer, primary_key=True) 270 | fitting_id = Column(Integer, ForeignKey('fittings.id'), nullable=False, index=True) 271 | created_at = Column(DateTime, default=datetime.datetime.now) 272 | value = Column(JSON) 273 | key = Column(JSON) 274 | features = Column(JSON) 275 | custom_data = Column(JSON) 276 | 277 | fitting = relationship('Fitting', back_populates='predictions') 278 | 279 | 280 | class FeatureMetaData(Crud, Base): 281 | __tablename__ = 'feature_metadata' 282 | __table_args__ = ( 283 | UniqueConstraint('entity_name', 'feature_name', 'snapshot_at', name='unique_entity_feature_ts'), ) 284 | id = Column(Integer, primary_key=True) 285 | created_at = 
Column(DateTime, default=datetime.datetime.now) 286 | entity_name = Column(String, nullable=False) 287 | feature_name = Column(String, nullable=False) 288 | feature_dtypes = Column(JSON) 289 | version = Column(String, nullable=False) 290 | snapshot_at = Column(DateTime) 291 | s3_url = Column(String) 292 | 293 | feature_data = relationship('Feature', back_populates='feature_metadata') 294 | 295 | 296 | class Feature(Crud, Base): 297 | __table_args__ = ( 298 | Index('feature_metadata_id', 'key', unique=True),) 299 | 300 | id = Column(Integer, primary_key=True) 301 | feature_metadata_id = Column(Integer, ForeignKey('feature_metadata.id'), nullable=False) 302 | created_at = Column(DateTime, default=datetime.datetime.now) 303 | key = Column(JSONB, nullable=False) 304 | feature_data = Column(String) 305 | 306 | feature_metadata = relationship('FeatureMetaData', back_populates='feature_data') 307 | 308 | 309 | Base.metadata.create_all(engine) 310 | -------------------------------------------------------------------------------- /lore/models/__init__.py: -------------------------------------------------------------------------------- 1 | from lore.models import base 2 | -------------------------------------------------------------------------------- /lore/models/keras.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, dirname 3 | import logging 4 | import botocore 5 | 6 | import lore 7 | from lore.util import timer 8 | from lore.env import require 9 | 10 | require(lore.dependencies.H5PY) 11 | import h5py 12 | 13 | 14 | try: 15 | FileExistsError 16 | except NameError: 17 | FileExistsError = OSError 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class Base(lore.models.base.Base): 24 | def __init__(self, pipeline, estimator): 25 | super(Base, self).__init__(pipeline, estimator) 26 | 27 | def weights_path(self): 28 | return join(self.fitting_path(), 'weights.h5') 29 | 30 | def checkpoint_path(self): 31 | return join(self.fitting_path(), 'checkpoints/{epoch}.h5') 32 | 33 | def tensorboard_path(self): 34 | return join(self.fitting_path(), 'tensorboard') 35 | 36 | def timeline_path(self): 37 | return join(self.fitting_path(), 'timeline.json') 38 | 39 | def remote_weights_path(self): 40 | if self.fitting: 41 | return join(self.remote_path(), str(self.fitting.id), 'weights.h5') 42 | else: 43 | return join(self.remote_path(), 'weights.h5') 44 | 45 | @property 46 | def fitting(self): 47 | return self._fitting 48 | 49 | @fitting.setter 50 | def fitting(self, value): 51 | self._fitting = value 52 | if self._fitting is not None: 53 | 54 | if not os.path.exists(dirname(self.checkpoint_path())): 55 | try: 56 | os.makedirs(dirname(self.checkpoint_path())) 57 | except FileExistsError as ex: 58 | pass # race to create 59 | 60 | if not os.path.exists(dirname(self.tensorboard_path())): 61 | try: 62 | os.makedirs(dirname(self.tensorboard_path())) 63 | except FileExistsError as ex: 64 | pass # race to create 65 | 66 | def save(self, stats=None): 67 | super(Base, self).save(stats) 68 | 69 | with timer('save weights'): 70 | # Only save weights, because saving named layers that have shared 71 | # weights causes an error on reload 72 | self.estimator.keras.save_weights(self.weights_path()) 73 | 74 | # Patch for keras 2 models saved with optimizer weights: 75 | # https://github.com/gagnonlg/explore-ml/commit/c05b01076c7eb99dae6a480a05ac14765ef08e4b 76 | with h5py.File(self.weights_path(), 'a') as f: 77 | if 'optimizer_weights' 
in f.keys(): 78 | del f['optimizer_weights'] 79 | 80 | @classmethod 81 | def load(cls, fitting_id=None): 82 | model = super(Base, cls).load(fitting_id) 83 | 84 | if hasattr(model, 'estimator'): 85 | # HACK to set estimator model, and model serializer 86 | model.estimator = model.estimator 87 | 88 | # Rely on build + load_weights rather than loading the named layers 89 | # w/ Keras for efficiency (and also because it causes a 90 | # deserialization issue) as of Keras 2.0.4: 91 | # https://github.com/fchollet/keras/issues/5442 92 | model.estimator.build() 93 | 94 | try: 95 | with timer('load weights %i' % model.fitting.id): 96 | model.estimator.keras.load_weights(model.weights_path()) 97 | except ValueError as ex: 98 | if model.estimator.multi_gpu_model and not lore.estimators.keras.available_gpus: 99 | error = "Trying to load a multi_gpu_model when no GPUs are present is not supported" 100 | logger.exception(error) 101 | raise NotImplementedError(error) 102 | else: 103 | raise 104 | 105 | else: 106 | model.build() 107 | with timer('load weights'): 108 | model.keras.load_weights(model.weights_path()) 109 | 110 | return model 111 | 112 | def upload(self): 113 | super(Base, self).upload() 114 | lore.io.upload(self.weights_path(), self.remote_weights_path()) 115 | 116 | @classmethod 117 | def download(cls, fitting_id=0): 118 | model = cls(None, None) 119 | if fitting_id is None: 120 | model.fitting = model.last_fitting() 121 | else: 122 | model.fitting = lore.metadata.Fitting.get(fitting_id) 123 | 124 | if model.fitting is None: 125 | logger.warning("Attempting to download a model from outside of the metadata store is deprecated and will be removed in 0.8.0") 126 | model.fitting = lore.metadata.Fitting(id=fitting_id) 127 | 128 | try: 129 | lore.io.download(model.remote_weights_path(), model.weights_path()) 130 | except botocore.exceptions.ClientError as e: 131 | if e.response['Error']['Code'] == "404": 132 | model.fitting.id = None 133 | logger.warning("Attempting to download a model without a fitting id is deprecated and will be removed in 0.8.0") 134 | lore.io.download(model.remote_weights_path(), model.weights_path()) 135 | return super(Base, cls).download(model.fitting.id) 136 | -------------------------------------------------------------------------------- /lore/models/naive.py: -------------------------------------------------------------------------------- 1 | import lore.models.base 2 | 3 | 4 | class Base(lore.models.base.Base): 5 | pass 6 | -------------------------------------------------------------------------------- /lore/models/sklearn.py: -------------------------------------------------------------------------------- 1 | import lore.models.base 2 | 3 | 4 | class Base(lore.models.base.Base): 5 | pass 6 | -------------------------------------------------------------------------------- /lore/models/xgboost.py: -------------------------------------------------------------------------------- 1 | import lore.models.base 2 | 3 | 4 | class Base(lore.models.base.Base): 5 | pass 6 | -------------------------------------------------------------------------------- /lore/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from collections import namedtuple 4 | 5 | Observations = namedtuple('Observations', 'x y') 6 | 7 | -------------------------------------------------------------------------------- /lore/pipelines/holdout.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from abc import ABCMeta, abstractmethod 4 | from collections import OrderedDict, Iterable 5 | import gc 6 | import logging 7 | import multiprocessing 8 | 9 | import lore 10 | from lore.env import require 11 | from lore.util import timer, timed 12 | from lore.pipelines import Observations 13 | 14 | require( 15 | lore.dependencies.NUMPY + 16 | lore.dependencies.PANDAS + 17 | lore.dependencies.SKLEARN 18 | ) 19 | import numpy 20 | import pandas 21 | from sklearn.model_selection import train_test_split 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class Base(object): 28 | __metaclass__ = ABCMeta 29 | 30 | test_size = 0.1 31 | 32 | def __init__(self): 33 | self.name = self.__module__ + '.' + self.__class__.__name__ 34 | self.stratify = None 35 | self.subsample = None 36 | self.split_seed = 1 37 | self.index = [] 38 | self.multiprocessing = False 39 | self.workers = None 40 | self._data = None 41 | self._encoders = None 42 | self._training_data = None 43 | self._test_data = None 44 | self._validation_data = None 45 | self._output_encoder = None 46 | self._encoded_training_data = None 47 | self._encoded_validation_data = None 48 | self._encoded_test_data = None 49 | 50 | def __getstate__(self): 51 | state = dict(self.__dict__) 52 | # bloat can be restored via self.__init__() + self.build() 53 | for bloat in [ 54 | '_data', 55 | '_training_data', 56 | '_test_data', 57 | '_validation_data', 58 | '_encoded_training_data', 59 | '_encoded_validation_data', 60 | '_encoded_test_data', 61 | ]: 62 | state[bloat] = None 63 | return state 64 | 65 | def __setstate__(self, dict): 66 | self.__dict__ = dict 67 | backward_compatible_defaults = { 68 | 'index': [], 69 | 'multiprocessing': False, 70 | 'workers': None, 71 | } 72 | for key, default in backward_compatible_defaults.items(): 73 | if key not in self.__dict__.keys(): 74 | self.__dict__[key] = default 75 | 76 | @abstractmethod 77 | def get_data(self): 78 | pass 79 | 80 | @abstractmethod 81 | def get_encoders(self): 82 | pass 83 | 84 | @abstractmethod 85 | def get_output_encoder(self): 86 | pass 87 | 88 | @property 89 | def encoders(self): 90 | if self._encoders is None: 91 | with timer('fit encoders'): 92 | self._encoders = self.get_encoders() 93 | 94 | # Ensure we have an iterable for all single encoder cases 95 | if not isinstance(self._encoders, Iterable): 96 | if len((self._encoders, )) == 1: 97 | self._encoders = (self._encoders, ) 98 | 99 | if self.multiprocessing: 100 | pool = multiprocessing.Pool(self.workers) 101 | results = [] 102 | for encoder in self._encoders: 103 | results.append(pool.apply_async(self.fit, (encoder, self.training_data))) 104 | self._encoders = [result.get() for result in results] 105 | 106 | else: 107 | for encoder in self._encoders: 108 | encoder.fit(self.training_data) 109 | 110 | return self._encoders 111 | 112 | @property 113 | def output_encoder(self): 114 | if self._output_encoder is None: 115 | with timer('fit output encoder'): 116 | self._output_encoder = self.get_output_encoder() 117 | self._output_encoder.fit(self.training_data) 118 | 119 | return self._output_encoder 120 | 121 | @property 122 | def training_data(self): 123 | if self._training_data is None: 124 | self._split_data() 125 | 126 | return self._training_data 127 | 128 | @property 129 | def validation_data(self): 130 | if self._validation_data is None: 131 | self._split_data() 132 | 133 | return 
self._validation_data 134 | 135 | @property 136 | def test_data(self): 137 | if self._test_data is None: 138 | self._split_data() 139 | 140 | return self._test_data 141 | 142 | @property 143 | def encoded_training_data(self): 144 | if not self._encoded_training_data: 145 | with timer('encode training data'): 146 | self._encoded_training_data = self.observations(self.training_data) 147 | 148 | return self._encoded_training_data 149 | 150 | @property 151 | def encoded_validation_data(self): 152 | if not self._encoded_validation_data: 153 | with timer('encode validation data'): 154 | self._encoded_validation_data = self.observations(self.validation_data) 155 | 156 | return self._encoded_validation_data 157 | 158 | @property 159 | def encoded_test_data(self): 160 | if not self._encoded_test_data: 161 | with timer('encode test data'): 162 | self._encoded_test_data = self.observations(self.test_data) 163 | 164 | return self._encoded_test_data 165 | 166 | def observations(self, data): 167 | return Observations(x=self.encode_x(data), y=self.encode_y(data)) 168 | 169 | @timed(logging.INFO) 170 | def encode_x(self, data): 171 | """ 172 | :param data: unencoded input dataframe 173 | :return: a dict with encoded values 174 | """ 175 | encoded = OrderedDict() 176 | if self.multiprocessing: 177 | pool = multiprocessing.Pool(self.workers) 178 | results = [] 179 | for encoder in self.encoders: 180 | results.append((encoder, pool.apply_async(self.transform, (encoder, data)))) 181 | 182 | for encoder, result in results: 183 | self.merged_transformed(encoded, encoder, result.get()) 184 | 185 | else: 186 | for encoder in self.encoders: 187 | self.merged_transformed(encoded, encoder, self.transform(encoder, data), append_twin=False) 188 | if encoder.twin: 189 | self.merged_transformed(encoded, encoder, self.transform(encoder, data, append_twin = True), append_twin=True) 190 | 191 | for column in self.index: 192 | encoded[column] = self.read_column(data, column) 193 | 194 | # Using a DataFrame as a container temporarily requires double the memory, 195 | # as pandas copies all data on __init__. 
This is justified by having a 196 | # type supported by all dependent libraries (heterogeneous dict is not) 197 | dataframe = pandas.DataFrame(encoded) 198 | if self.index: 199 | dataframe = dataframe.set_index(self.index)  # set_index returns a copy; the result was previously discarded 200 | return dataframe 201 | 202 | def fit(self, encoder, data): 203 | encoder.fit(data) 204 | return encoder 205 | 206 | def transform(self, encoder, data, append_twin=False): 207 | if append_twin: 208 | return encoder.transform(self.read_column(data, encoder.twin_column)) 209 | else: 210 | return encoder.transform(self.read_column(data, encoder.source_column)) 211 | 212 | @staticmethod 213 | def merged_transformed(encoded, encoder, transformed, append_twin=False): 214 | if hasattr(encoder, 'sequence_length'): 215 | for i in range(encoder.sequence_length): 216 | if isinstance(transformed, pandas.DataFrame): 217 | if append_twin: 218 | encoded[encoder.sequence_name(i, suffix="_twin")] = transformed.iloc[:, i] 219 | else: 220 | encoded[encoder.sequence_name(i)] = transformed.iloc[:, i] 221 | else: 222 | if append_twin: 223 | encoded[encoder.sequence_name(i, suffix="_twin")] = transformed[:, i] 224 | else: 225 | encoded[encoder.sequence_name(i)] = transformed[:, i] 226 | 227 | else: 228 | if append_twin: 229 | encoded[encoder.twin_name] = transformed 230 | else: 231 | encoded[encoder.name] = transformed 232 | 233 | 234 | 235 | @timed(logging.INFO) 236 | def encode_y(self, data): 237 | if self.output_encoder.source_column in data.columns: 238 | return self.output_encoder.transform(self.read_column(data, self.output_encoder.source_column)) 239 | else: 240 | return None 241 | 242 | @timed(logging.INFO) 243 | def decode(self, data): 244 | decoded = OrderedDict() 245 | for encoder in self.encoders: 246 | decoded[encoder.name.split('_', 1)[-1]] = encoder.reverse_transform(data[encoder.name]) 247 | return pandas.DataFrame(decoded) 248 | 249 | def read_column(self, data, column): 250 | """ 251 | Implemented so subclasses can override it to handle different types of columnar data 252 | 253 | :param data: 254 | :param column: 255 | :return: 256 | """ 257 | return data[column] 258 | 259 | @timed(logging.INFO) 260 | def _split_data(self): 261 | if self._data is not None:  # avoid ambiguous DataFrame truthiness 262 | return 263 | 264 | numpy.random.seed(self.split_seed) 265 | logger.debug('random seed set to: %i' % self.split_seed) 266 | 267 | self._data = self.get_data() 268 | gc.collect() 269 | if self.subsample: 270 | 271 | if self.stratify: 272 | logger.debug('subsampling stratified by `%s`: %s' % ( 273 | self.stratify, self.subsample)) 274 | ids = self._data[[self.stratify]].drop_duplicates() 275 | ids = ids.sample(self.subsample) 276 | self._data = pandas.merge(self._data, ids, on=self.stratify) 277 | else: 278 | logger.debug('subsampling rows: %s' % self.subsample) 279 | self._data = self._data.sample(self.subsample) 280 | gc.collect() 281 | 282 | if self.stratify: 283 | ids = self._data[self.stratify].drop_duplicates() 284 | 285 | train_ids, validate_ids = train_test_split( 286 | ids, 287 | test_size=self.test_size, 288 | random_state=self.split_seed 289 | ) 290 | gc.collect() 291 | train_ids, test_ids = train_test_split( 292 | train_ids, 293 | test_size=self.test_size, 294 | random_state=self.split_seed 295 | ) 296 | gc.collect() 297 | 298 | rows = self._data[self.stratify].values 299 | self._training_data = self._data.iloc[numpy.isin(rows, train_ids.values)] 300 | self._validation_data = self._data.iloc[numpy.isin(rows, validate_ids.values)] 301 | self._test_data = self._data.iloc[numpy.isin(rows, test_ids.values)] 302 | else: 303 | 
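# (Editor's note, illustrative only) The two train_test_split calls below
# carve out the holdouts sequentially: with the default test_size of 0.1,
# validation receives 10% of all rows, test receives 10% of the remaining
# 90% (9% overall), and training keeps the other 81%. For 1000 rows:
#
#     validation = int(1000 * 0.10)          # 100 rows
#     test       = int((1000 - 100) * 0.10)  # 90 rows
#     training   = 1000 - 100 - 90           # 810 rows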
self._training_data, self._validation_data = train_test_split( 304 | self._data, 305 | test_size=self.test_size, 306 | random_state=self.split_seed 307 | ) 308 | 309 | self._training_data, self._test_data = train_test_split( 310 | self._training_data, 311 | test_size=self.test_size, 312 | random_state=self.split_seed 313 | ) 314 | gc.collect() 315 | 316 | self._data = None 317 | gc.collect() 318 | 319 | # It's important to reset these indexes after the split so that, in case 320 | # these dataframes are copied, the missing split rows are 321 | # not re-materialized later full of NaNs. 322 | self._training_data.reset_index(drop=True, inplace=True) 323 | gc.collect() 324 | self._validation_data.reset_index(drop=True, inplace=True) 325 | gc.collect() 326 | self._test_data.reset_index(drop=True, inplace=True) 327 | gc.collect() 328 | 329 | logger.debug('training: %i | validation: %i | test: %i' % ( 330 | len(self._training_data), 331 | len(self._validation_data), 332 | len(self._test_data) 333 | )) 334 | -------------------------------------------------------------------------------- /lore/pipelines/time_series.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from abc import ABCMeta 4 | import logging 5 | 6 | import lore 7 | from lore.util import timed 8 | import lore.pipelines.holdout 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class Base(lore.pipelines.holdout.Base): 14 | __metaclass__ = ABCMeta 15 | 16 | def __init__(self, test_size=0.1, sort_by=None): 17 | super(Base, self).__init__() 18 | self.sort_by = sort_by 19 | self.test_size = test_size 20 | 21 | @timed(logging.INFO) 22 | def _split_data(self): 23 | if self._data is not None:  # avoid ambiguous DataFrame truthiness 24 | return 25 | 26 | logger.debug('No shuffle test train split') 27 | 28 | self._data = self.get_data() 29 | 30 | if self.sort_by: 31 | self._data = self._data.sort_values(by=self.sort_by, ascending=True) 32 | test_rows = int(len(self._data) * self.test_size) 33 | valid_rows = test_rows 34 | train_rows = int(len(self._data) - test_rows - valid_rows) 35 | self._training_data = self._data.iloc[:train_rows] 36 | self._validation_data = self._data.iloc[train_rows:train_rows+valid_rows] 37 | self._test_data = self._data.iloc[train_rows + valid_rows:]  # avoids iloc[-0:] selecting every row when test_rows == 0 38 | -------------------------------------------------------------------------------- /lore/stores/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lore 3 | from lore.stores.disk import Disk 4 | from lore.stores.ram import Ram 5 | 6 | 7 | cache = Ram() 8 | query_cache = Disk(os.path.join(lore.env.DATA_DIR, 'query_cache')) 9 | 10 | 11 | def cached(func): 12 | global cache 13 | return _cached(func, cache) 14 | 15 | 16 | def query_cached(func): 17 | global query_cache 18 | return _cached(func, query_cache) 19 | 20 | 21 | def _cached(func, store): 22 | def wrapper(self, *args, **kwargs): 23 | cache = kwargs.pop('cache', False) 24 | if not cache: 25 | return func(self, *args, **kwargs) 26 | 27 | key = store.key(instance=self, caller=func, *args, **kwargs) 28 | if key not in store: 29 | store[key] = func(self, *args, **kwargs) 30 | return store[key] 31 | 32 | return wrapper 33 | -------------------------------------------------------------------------------- /lore/stores/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import hashlib 4 | import inspect 5 | 6 | 7 | class Base(object): 8 | 
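# (Editor's note, illustrative only) `__metaclass__ = ABCMeta` is the
# Python 2 spelling used throughout this repo; Python 3 ignores the
# attribute, so the @abstractmethod decorators below document the contract
# rather than enforce it. The Python 3 equivalent would be:
#
#     class Base(metaclass=ABCMeta):
#         ...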
__metaclass__ = ABCMeta 9 | 10 | @abstractmethod 11 | def __getitem__(self, key): 12 | pass 13 | 14 | @abstractmethod 15 | def __setitem__(self, key, value): 16 | pass 17 | 18 | @abstractmethod 19 | def __delitem__(self, key): 20 | pass 21 | 22 | @abstractmethod 23 | def __contains__(self, key): 24 | pass 25 | 26 | @abstractmethod 27 | def __len__(self): 28 | pass 29 | 30 | @abstractmethod 31 | def keys(self): 32 | pass 33 | 34 | @abstractmethod 35 | def values(self): 36 | pass 37 | 38 | @abstractmethod 39 | def batch_set(self): 40 | pass 41 | 42 | @abstractmethod 43 | def batch_get(self, data_dict): 44 | pass 45 | 46 | def key(self, *args, **kwargs): 47 | stack = inspect.stack() 48 | caller = kwargs.pop('caller', stack[-2]) 49 | instance = kwargs.pop('instance', self) 50 | 51 | return '.'.join(( 52 | instance.__module__, 53 | instance.__class__.__name__, 54 | caller.__code__.co_name, 55 | hashlib.sha1(str(args).encode('utf-8') + str(kwargs).encode('utf-8')).hexdigest() 56 | )) 57 | -------------------------------------------------------------------------------- /lore/stores/disk.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import pickle 4 | 5 | import lore 6 | from lore.stores.base import Base 7 | from lore.util import timer 8 | from lore.env import require 9 | 10 | require( 11 | lore.dependencies.PANDAS 12 | ) 13 | 14 | import pandas 15 | from datetime import datetime 16 | 17 | try: 18 | FileExistsError 19 | except NameError: 20 | FileExistsError = OSError 21 | 22 | 23 | class Disk(Base): 24 | EXTENSION = '.pickle' 25 | 26 | def __init__(self, dir): 27 | self.dir = dir 28 | self.limit = None 29 | if not os.path.exists(self.dir): 30 | try: 31 | os.makedirs(self.dir) 32 | except FileExistsError as ex: 33 | pass # race to create 34 | 35 | def __getitem__(self, key): 36 | if key in self: 37 | with timer('read %s' % key): 38 | with open(self._path(key), 'rb') as f: 39 | return pickle.load(f) 40 | return None 41 | 42 | def __setitem__(self, key, value): 43 | with timer('write %s' % key): 44 | if isinstance(value, pandas.core.frame.DataFrame): 45 | if len(value): 46 | for c in value.columns: 47 | first = value[c].iloc[0] 48 | if isinstance(first, datetime) and \ 49 | type(first.tzinfo).__name__.startswith('GMT'): 50 | value[c] = pandas.to_datetime(value[c], utc=True) 51 | 52 | with open(self._path(key), 'wb') as f: 53 | pickle.dump(value, f, pickle.HIGHEST_PROTOCOL) 54 | gc.collect() 55 | 56 | if self.limit is not None: 57 | if os.path.getsize(self._path(key)) > self.limit: 58 | raise MemoryError('disk cache limit exceeded by single key: %s' % key) 59 | 60 | with timer('evict %s' % key): 61 | while self.size() > self.limit: 62 | del self[self.lru()] 63 | gc.collect() 64 | 65 | def __delitem__(self, key): 66 | os.remove(self._path(key)) 67 | 68 | def __contains__(self, key): 69 | return os.path.isfile(self._path(key)) 70 | 71 | def __len__(self): 72 | return len(self.keys()) 73 | 74 | def batch_get(self, keys): 75 | result = {} 76 | for key in keys: 77 | result[key] = self.__getitem__(key) 78 | return result 79 | 80 | def batch_set(self, data_dict): 81 | for key, value in data_dict.items(): 82 | self.__setitem__(key, value) 83 | 84 | def size(self): 85 | return sum(os.path.getsize(f) for f in self.values()) 86 | 87 | def keys(self): 88 | return [self._key(f) for f in os.listdir(self.dir)] 89 | 90 | def values(self): 91 | return [os.path.join(self.dir, f) for f in os.listdir(self.dir)] 92 | 93 | def lru(self): 94 | files = 
sorted(self.values(), key=lambda f: os.stat(f).st_atime) 95 | 96 | if not files: 97 | return None 98 | 99 | return self._key(files[0]) 100 | 101 | def _path(self, key): 102 | return os.path.join(self.dir, key + self.EXTENSION) 103 | 104 | def _key(self, path): 105 | return os.path.basename(path)[0:-len(self.EXTENSION)] 106 | -------------------------------------------------------------------------------- /lore/stores/ram.py: -------------------------------------------------------------------------------- 1 | from lore.stores.base import Base 2 | 3 | 4 | class Ram(dict, Base): 5 | pass 6 | -------------------------------------------------------------------------------- /lore/stores/redis.py: -------------------------------------------------------------------------------- 1 | from lore.stores.base import Base 2 | 3 | 4 | class Redis(Base): 5 | def __init__(self, redis_conn): 6 | self.r = redis_conn 7 | 8 | def __getitem__(self, key): 9 | return self.r.get(key) 10 | 11 | def __setitem__(self, key, value): 12 | self.r.set(key, value) 13 | 14 | def __delitem__(self, key): 15 | self.r.delete(key) 16 | 17 | def __contains__(self, key): 18 | return self.r.get(key) is not None 19 | 20 | def __len__(self): 21 | raise NotImplementedError('Operation can be expensive. Aborting') 22 | 23 | def keys(self): 24 | raise NotImplementedError('Operation can be expensive. Aborting') 25 | 26 | def values(self): 27 | raise NotImplementedError('Operation can be expensive. Aborting') 28 | 29 | def batch_get(self, keys): 30 | return self.r.mget(keys) 31 | 32 | def batch_set(self, data_dict): 33 | self.r.mset(data_dict) 34 | 35 | -------------------------------------------------------------------------------- /lore/stores/s3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import lore.io 5 | from lore.stores.base import Base 6 | from lore.util import timer 7 | 8 | 9 | # TODO 10 | class S3(Base): 11 | pass 12 | -------------------------------------------------------------------------------- /lore/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/lore/tasks/__init__.py -------------------------------------------------------------------------------- /lore/tasks/base.py: -------------------------------------------------------------------------------- 1 | class Base(object): 2 | pass 3 | -------------------------------------------------------------------------------- /lore/template/architecture.py.j2: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import lore\n", 20 | "from {{app_name}}.models.{{module_name}} import Keras\n", 21 | "model = Keras.load()" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from IPython.display import SVG\n", 31 | "from keras.utils.vis_utils import model_to_dot\n", 32 | "\n", 33 | "SVG(model_to_dot(model.estimator.keras).create(prog='dot', format='svg'))" 34 | ] 35 | } 36 | ], 37 | "metadata": { 38 | "kernelspec": { 39 | "display_name": "{{app_name}}", 40 | "language": "python", 41 | "name":
"{{app_name}}" 42 | }, 43 | "language_info": { 44 | "codemirror_mode": { 45 | "name": "ipython", 46 | "version": {{major_version}} 47 | }, 48 | "file_extension": ".py", 49 | "mimetype": "text/x-python", 50 | "name": "python", 51 | "nbconvert_exporter": "python", 52 | "pygments_lexer": "ipython3", 53 | "version": "{{full_version}}" 54 | } 55 | }, 56 | "nbformat": 4, 57 | "nbformat_minor": 2 58 | } -------------------------------------------------------------------------------- /lore/template/estimator.py.j2: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | {% if keras %} 4 | import lore.estimators.keras 5 | {% endif %} 6 | {% if xgboost %} 7 | import lore.estimators.xgboost 8 | {% endif %} 9 | {% if sklearn %} 10 | import lore.estimators.sklearn 11 | {% endif %} 12 | 13 | logger = logging.getLogger(__name__) 14 | {% if keras %} 15 | 16 | 17 | class Keras(lore.estimators.keras.{{base}}): 18 | pass 19 | {% endif %} 20 | {% if xgboost %} 21 | 22 | 23 | class XGBoost(lore.estimators.xgboost.{{base}}): 24 | pass 25 | {% endif %} 26 | {% if sklearn %} 27 | 28 | 29 | class SKLearn(lore.estimators.sklearn.{{base}}): 30 | pass 31 | {% endif %} 32 | -------------------------------------------------------------------------------- /lore/template/features.py.j2: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import matplotlib.pyplot\n", 20 | "import matplotlib\n", 21 | "matplotlib.style.use('seaborn')\n", 22 | "matplotlib.rc('axes', grid=True)\n", 23 | "%matplotlib inline\n", 24 | "%config InlineBackend.figure_format = 'retina'\n", 25 | "\n", 26 | "import lore\n", 27 | "from {{app_name}}.models.{{module_name}} import Keras\n", 28 | "model = Keras.load()" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "test = model.pipeline.encoded_test_data.x.copy()\n", 38 | "test['response'] = model.pipeline.test_data[model.pipeline.output_encoder.column]\n", 39 | "test['color'] = 'gold'\n", 40 | "\n", 41 | "predict = model.pipeline.encoded_test_data.x.copy()\n", 42 | "predict['response'] = model.predict(model.pipeline.test_data)\n", 43 | "predict['color'] = 'blue'\n", 44 | "\n", 45 | "data = test.append(predict)\n", 46 | "\n", 47 | "def plot_encoder(name):\n", 48 | " stats = data.groupby(['color', name]).agg({'response': ['mean', 'size']}).reset_index()\n", 49 | " stats.columns = ['color', name, 'response', 'population']\n", 50 | " stats['population'] = (stats['population'] / stats['population'].max() * 1000).clip(lower=3)\n", 51 | " stats.plot.scatter(x=name, y='response', s=stats['population'], figsize=(16, 9), c=stats['color'], alpha=0.75) \n", 52 | "\n", 53 | "for encoder in model.pipeline.encoders:\n", 54 | " if hasattr(encoder, 'sequence_length'):\n", 55 | " for i in range(encoder.sequence_length):\n", 56 | " plot_encoder(encoder.sequence_name(i))\n", 57 | " else:\n", 58 | " plot_encoder(encoder.name)" 59 | ] 60 | } 61 | ], 62 | "metadata": { 63 | "kernelspec": { 64 | "display_name": "{{app_name}}", 65 | "language": "python", 66 | "name": "{{app_name}}" 67 | }, 68 | "language_info": { 69 | 
"codemirror_mode": { 70 | "name": "ipython", 71 | "version": {{major_version}} 72 | }, 73 | "file_extension": ".py", 74 | "mimetype": "text/x-python", 75 | "name": "python", 76 | "nbconvert_exporter": "python", 77 | "pygments_lexer": "ipython3", 78 | "version": "{{full_version}}" 79 | } 80 | }, 81 | "nbformat": 4, 82 | "nbformat_minor": 2 83 | } 84 | -------------------------------------------------------------------------------- /lore/template/init/.env.template: -------------------------------------------------------------------------------- 1 | LORE_ENV=development 2 | DATABASE_URL=postgres://localhost:5432/development 3 | REDIS_URL=redis://localhost:6379 4 | ROLLBAR_ACCESS_TOKEN=00000000000000000000000000000000 5 | -------------------------------------------------------------------------------- /lore/template/init/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Setuptools distribution folder. 5 | dist/ 6 | build/ 7 | 8 | # Python egg metadata, regenerated from source files by setuptools. 9 | *.egg-info 10 | 11 | # IDE files 12 | .idea 13 | 14 | # Jupyter 15 | notebooks/**/.ipynb_checkpoints/* 16 | 17 | # Lore local developement files 18 | .env 19 | logs/* 20 | data/* 21 | models/* 22 | tests/models/* 23 | -------------------------------------------------------------------------------- /lore/template/init/.keras/keras.json: -------------------------------------------------------------------------------- 1 | { 2 | "floatx": "float32", 3 | "image_dim_ordering": "tf", 4 | "backend": "tensorflow", 5 | "epsilon": 1e-07 6 | } 7 | -------------------------------------------------------------------------------- /lore/template/init/Procfile: -------------------------------------------------------------------------------- 1 | web: lore server 2 | -------------------------------------------------------------------------------- /lore/template/init/README.rst: -------------------------------------------------------------------------------- 1 | {{app_name}} 2 | ========== 3 | 4 | System Setup 5 | ------------ 6 | 7 | 1) Get your system setup 8 | 9 | .. code:: 10 | 11 | $ lore install 12 | 13 | 2) Set correct variables in `.env` 14 | 15 | .. code:: 16 | 17 | $ cp .env.template .env 18 | $ edit .env 19 | 20 | Running 21 | ------- 22 | 23 | To run locally: 24 | 25 | .. code:: 26 | 27 | $ lore server 28 | 29 | Testing 30 | ------- 31 | 32 | To test locally: 33 | 34 | .. code:: 35 | 36 | $ lore test 37 | 38 | Training 39 | -------- 40 | 41 | To train locally: 42 | 43 | .. code:: 44 | 45 | $ lore fit MODEL 46 | 47 | Deploying 48 | --------- 49 | 50 | .. 
code:: 51 | 52 | $ git push heroku master 53 | -------------------------------------------------------------------------------- /lore/template/init/app/__init__.py: -------------------------------------------------------------------------------- 1 | import lore 2 | import lore.stores 3 | import os 4 | 5 | lore.env.APP = __name__ 6 | lore.stores.query_cache.limit = int(os.environ.get('LORE_QUERY_CACHE_LIMIT', 10000000000)) 7 | -------------------------------------------------------------------------------- /lore/template/init/app/estimators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/lore/template/init/app/estimators/__init__.py -------------------------------------------------------------------------------- /lore/template/init/app/extracts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/lore/template/init/app/extracts/.gitkeep -------------------------------------------------------------------------------- /lore/template/init/app/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/lore/template/init/app/models/__init__.py -------------------------------------------------------------------------------- /lore/template/init/app/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/lore/template/init/app/pipelines/__init__.py -------------------------------------------------------------------------------- /lore/template/init/config/aws.cfg: -------------------------------------------------------------------------------- 1 | [IAM] 2 | role: $IAM_ROLE 3 | 4 | [ACCESS_KEY] 5 | id: $AWS_ACCESS_KEY_ID 6 | secret: $AWS_SECRET_ACCESS_KEY 7 | 8 | [BUCKET] 9 | name: lore 10 | -------------------------------------------------------------------------------- /lore/template/init/config/database.cfg: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | url: $DATABASE_URL 3 | -------------------------------------------------------------------------------- /lore/template/init/notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/lore/template/init/notebooks/.gitkeep -------------------------------------------------------------------------------- /lore/template/init/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/lore/template/init/tests/__init__.py -------------------------------------------------------------------------------- /lore/template/init/tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/lore/template/init/tests/unit/__init__.py -------------------------------------------------------------------------------- /lore/template/model.py.j2: 
-------------------------------------------------------------------------------- 1 | import logging 2 | 3 | {% if keras %} 4 | import lore.models.keras 5 | {% endif %} 6 | {% if xgboost %} 7 | import lore.models.xgboost 8 | {% endif %} 9 | {% if sklearn %} 10 | import lore.models.sklearn 11 | {% endif %} 12 | 13 | import {{app_name}}.pipelines.{{module_name}} 14 | import {{app_name}}.estimators.{{module_name}} 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | {% if keras %} 19 | 20 | 21 | class Keras(lore.models.keras.Base): 22 | def __init__(self, pipeline=None, estimator=None): 23 | super(Keras, self).__init__( 24 | {{app_name}}.pipelines.{{module_name}}.Holdout(), 25 | {{app_name}}.estimators.{{module_name}}.Keras() 26 | ) 27 | {% endif %} 28 | {% if xgboost %} 29 | 30 | 31 | class XGBoost(lore.models.xgboost.Base): 32 | def __init__(self, pipeline=None, estimator=None): 33 | super(XGBoost, self).__init__( 34 | {{app_name}}.pipelines.{{module_name}}.Holdout(), 35 | {{app_name}}.estimators.{{module_name}}.XGBoost() 36 | ) 37 | {% endif %} 38 | {% if sklearn %} 39 | 40 | 41 | class SKLearn(lore.models.sklearn.Base): 42 | def __init__(self, pipeline=None, estimator=None): 43 | super(SKLearn, self).__init__( 44 | {{app_name}}.pipelines.{{module_name}}.Holdout(), 45 | {{app_name}}.estimators.{{module_name}}.SKLearn() 46 | ) 47 | {% endif %} 48 | -------------------------------------------------------------------------------- /lore/template/pipeline.py.j2: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import lore.pipelines.holdout 4 | from lore.util import timed 5 | 6 | import pandas 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class Holdout(lore.pipelines.holdout.Base): 13 | @timed(logging.INFO) 14 | def get_data(self): 15 | """ 16 | 17 | :return: pandas.DataFrame of raw data to be split into train/validation/test sets 18 | """ 19 | # TODO return the full data set as a dataframe, one column per encoder 20 | return pandas.DataFrame() 21 | 22 | @timed(logging.INFO) 23 | def get_encoders(self): 24 | # TODO return a tuple of instances from lore.encoders 25 | return (None,) 26 | 27 | 28 | @timed(logging.INFO) 29 | def get_output_encoder(self): 30 | # TODO return a single instance from lore.encoders 31 | return None 32 | 33 | -------------------------------------------------------------------------------- /lore/template/test.py.j2: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import {{app_name}}.models.{{module_name}} 4 | 5 | class Test{{class_name}}(unittest.TestCase): 6 | 7 | def test_fit(self): 8 | {% if keras %} 9 | 10 | model = {{app_name}}.models.{{module_name}}.Keras() 11 | model.pipeline.subsample = 1000 12 | model.fit(epochs=1) 13 | self.assertTrue(True) 14 | {% endif %} 15 | {% if xgboost %} 16 | 17 | model = {{app_name}}.models.{{module_name}}.XGBoost() 18 | model.pipeline.subsample = 1000 19 | model.fit() 20 | self.assertTrue(True) 21 | {% endif %} 22 | {% if sklearn %} 23 | 24 | model = {{app_name}}.models.{{module_name}}.SKLearn() 25 | model.pipeline.subsample = 1000 26 | model.fit() 27 | self.assertTrue(True) 28 | {% endif %} 29 | 30 | -------------------------------------------------------------------------------- /lore/www/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import importlib 3 | import json 4 | import logging 5 | import pkgutil 6 | 7 | 8 | import lore
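# lore.www discovers every model class defined under '<app>/models', loads
# its saved fitting once at import time, and registers a Flask route
# '/<module>.<Class>/predict.json' for each one. Query string parameters
# become DataFrame columns, so repeating a parameter yields multiple
# prediction rows, e.g. (hypothetical app and model names):
#   GET /product_popularity.Keras/predict.json?product_id=1&product_id=2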
9 | import lore.util 10 | import lore.env 11 | from lore.env import require 12 | from lore.util import timer 13 | 14 | require( 15 | lore.dependencies.PANDAS + 16 | lore.dependencies.FLASK 17 | ) 18 | import pandas 19 | from flask import Flask, request 20 | 21 | 22 | app = Flask(lore.env.APP) 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | @app.route('/') 28 | def index(): 29 | names = str([name for _, name, _ in pkgutil.iter_modules([lore.env.APP + '/' + 'models'])]) 30 | return 'Hello %s!' % lore.env.APP + '\n' + names 31 | 32 | 33 | for module_finder, module_name, _ in pkgutil.iter_modules([lore.env.APP + '/' + 'models']): 34 | module = importlib.import_module(lore.env.APP + '.models.' + module_name) 35 | for class_name, member in inspect.getmembers(module): 36 | if not (inspect.isclass(member) and issubclass(member, lore.models.base.Base)): 37 | continue 38 | 39 | qualified_name = module_name + '.' + class_name 40 | with timer('load %s' % qualified_name): 41 | best = member.load() 42 | 43 | def predict(best=best):  # bind this iteration's model; a bare closure would leave every route calling the last model loaded 44 | logger.debug(request.args) 45 | data = {arg: request.args.getlist(arg) for arg in request.args.keys()} 46 | try: 47 | data = pandas.DataFrame(data) 48 | except ValueError: 49 | return 'Malformed data!', 400 50 | 51 | logger.debug(data) 52 | try: 53 | result = best.predict(data) 54 | except KeyError: 55 | return 'Missing data!', 400 56 | return json.dumps(result.tolist()), 200 57 | 58 | predict.__name__ = best.name + '.predict' 59 | 60 | rule = '/' + qualified_name + '/predict.json' 61 | logger.info('Adding url rule for prediction: %s' % rule) 62 | app.add_url_rule(rule, view_func=predict) 63 | -------------------------------------------------------------------------------- /notebooks/names.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%%javascript\n", 10 | "var command = \"nb_name = '\" + IPython.notebook.notebook_path + \"'; nb_name = nb_name.split('/')[-1]\";\n", 11 | "IPython.notebook.kernel.execute(command);" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import requests\n", 21 | "import datetime\n", 22 | "from lxml import etree\n", 23 | "import csv\n", 24 | "import lore\n", 25 | "import os\n", 26 | "import pandas\n", 27 | "lore_dir = os.path.join(os.path.dirname(nb_name), '..')\n", 28 | "os.chdir(lore_dir)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "Download popular names from the Social Security Administration" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "collapsed": true 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "url = 'https://www.ssa.gov/cgi-bin/popularnames.cgi'\n", 47 | "first_available = 1880\n", 48 | "most_recent = datetime.datetime.now().year - 1\n", 49 | "\n", 50 | "\n", 51 | "dir = os.path.join(lore.env.data_dir, 'usa_names')\n", 52 | "if not os.path.exists(dir):\n", 53 | " os.makedirs(dir)\n", 54 | "\n", 55 | "years = {}\n", 56 | "for year in range(first_available, most_recent + 1):\n", 57 | " path = os.path.join(dir, str(year) + '.csv')\n", 58 | " if not os.path.exists(path):\n", 59 | " response = requests.post(url, data={'year': year, 'top': 1000, 'number': 'n'})\n", 60 | " html = response.text\n", 61 | " tree = etree.HTML(html)\n", 62 | " with open(path,
'w') as file:\n", 63 | " writer = csv.writer(file)\n", 64 | " writer.writerow(['rank', 'male_name', 'male_count', 'female_name', 'female_count'])\n", 65 | " for row in tree.xpath('body/table[2]/tr/td[2]/table/tr'):\n", 66 | " tds = row.xpath('td')\n", 67 | " if tds:\n", 68 | " writer.writerow([td.text.replace(',', '') for td in tds if td.text])\n", 69 | " years[year] = pandas.DataFrame.from_csv(path)\n" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "url = 'https://www.ssa.gov/oact/STATS/table4c6.html'\n", 79 | "response = requests.get(url)\n", 80 | "html = response.text\n", 81 | "tree = etree.HTML(html)\n", 82 | "path = os.path.join(lore.env.data_dir, 'actuary.csv')\n", 83 | "tree.xpath('//*[@id=\"content\"]/div/div[2]/div/table[1]/tbody/tr[2]/td/table/tbody/tr')" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "Get life expectancy by sex" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": { 104 | "collapsed": true 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "url = 'https://www.ssa.gov/oact/STATS/table4c6.html'\n", 109 | "response = requests.get(url)\n", 110 | "html = response.text\n", 111 | "tree = etree.HTML(html)\n", 112 | "path = os.path.join(lore.env.data_dir, 'actuary.csv')\n", 113 | "with open(path, 'w') as file:\n", 114 | " writer = csv.writer(file)\n", 115 | " writer.writerow(['age','male death probability','male lives','male life expectancy','female death probability','female lives','female life expectancy'\n", 116 | "])\n", 117 | " for row in tree.xpath('//*[@id=\"content\"]/div/div[2]/div/table[1]/tbody/tr[2]/td/table/tbody/tr'):\n", 118 | " tds = row.xpath('td')\n", 119 | " if tds:\n", 120 | " writer.writerow([td.text.replace(',', '') for td in tds if td.text])\n", 121 | " \n", 122 | "\n", 123 | "actuary = pandas.DataFrame.from_csv(path)\n", 124 | "male_deaths = actuary['male death probability'].apply(lambda x: 1-x).cumprod()\n", 125 | "female_deaths = actuary['female death probability'].apply(lambda x: 1-x).cumprod()" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "collapsed": true 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "year_data = {}\n", 137 | "for year in years:\n", 138 | " path = os.path.join(dir, str(year) + '.csv')\n", 139 | " year_data[year] = pandas.DataFrame.from_csv(path)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": { 146 | "collapsed": true 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "from datetime import date, timedelta\n", 151 | "from collections import defaultdict\n", 152 | "this_year = date(date.today().year, 1, 1)\n", 153 | "year_delta = timedelta(days=365.24)\n", 154 | "name_years = defaultdict(lambda: defaultdict(dict))\n", 155 | "living_name_years = defaultdict(lambda: defaultdict(dict))\n", 156 | "\n", 157 | "for year, data in year_data.items():\n", 158 | " age = (this_year - date(year, 1, 1)) // year_delta\n", 159 | " for row in data.itertuples():\n", 160 | " name_years[row[1]][year]['male'] = row[2]\n", 161 | " name_years[row[3]][year]['female'] = row[4]\n", 162 | " if age < 120:\n", 163 | " living_name_years[row[1]][year]['male'] =
row[2] * male_deaths.iat[age]\n", 164 | " living_name_years[row[3]][year]['female'] = row[4] * female_deaths.iat[age]\n", 165 | " else:\n", 166 | " living_name_years[row[1]][year]['male'] = 0\n", 167 | " living_name_years[row[3]][year]['female'] = 0" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "name_stats = defaultdict(lambda: {'total': 0, 'male': 0, 'female': 0, 'mean_age': 0})\n", 177 | "for name in name_years:\n", 178 | " stats = name_stats[name]\n", 179 | " for year, sexes in living_name_years[name].items():\n", 180 | " age = (this_year - date(year, 1, 1)) / year_delta\n", 181 | " male = sexes.get('male', 0)\n", 182 | " female = sexes.get('female', 0)\n", 183 | " stats['male'] += male\n", 184 | " stats['female'] += female\n", 185 | " stats['total'] += (male + female)\n", 186 | " stats['mean_age'] += (male + female) * age\n", 187 | " if stats['total'] > 0:\n", 188 | " stats['mean_age'] = stats['mean_age'] / stats['total'] \n", 189 | " stats['sex'] = stats['male'] / stats['total']\n", 190 | " else:\n", 191 | " stats['mean_age'] = stats['sex'] = 0\n", 192 | " name_stats[name] = stats\n", 193 | "name_stats['Piper'] " 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": { 200 | "collapsed": true 201 | }, 202 | "outputs": [], 203 | "source": [ 204 | "from matplotlib import pyplot\n", 205 | "from matplotlib.patches import Patch\n", 206 | "\n", 207 | "def plot_sexes(name):\n", 208 | " years = range(first_available, most_recent + 1)\n", 209 | " males = [float(name_years[name][year].get('male', 0)) for year in years]\n", 210 | " females = [float(name_years[name][year].get('female', 0)) for year in years]\n", 211 | " living_males = [float(living_name_years[name][year].get('male', 0)) for year in years]\n", 212 | " living_females = [float(living_name_years[name][year].get('female', 0)) for year in years]\n", 213 | " pyplot.figure(dpi=200)\n", 214 | " pyplot.plot(years, males, label=\"male\", color='#ADD8E6')\n", 215 | " pyplot.plot(years, females, label=\"female\", color='#FFC0CB')\n", 216 | " pyplot.plot(years, living_males, label=\"living male\", color='#6666FF')\n", 217 | " pyplot.plot(years, living_females, label=\"living female\", color='#FF6666')\n", 218 | " pyplot.axvline(x=(most_recent - name_stats[name]['mean_age']), label=('mean age: %3.1f' % name_stats[name]['mean_age']), color='green')\n", 219 | " pyplot.plot()\n", 220 | " pyplot.xlabel('year')\n", 221 | " pyplot.ylabel('births')\n", 222 | " pyplot.title('Babies Named ' + name)\n", 223 | " pyplot.grid(True)\n", 224 | " pyplot.legend(loc='best')\n", 225 | " pyplot.show()" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "plot_sexes('Montana')" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "plot_sexes('Natalie')" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "plot_sexes('Josh')" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": { 259 | "collapsed": true 260 | }, 261 | "outputs": [], 262 | "source": [ 263 | "with open(os.path.join(dir, 'names.csv'), 'w') as file:\n", 264 | " writer = csv.writer(file)\n", 265 | " writer.writerow(('name',
'male', 'mean_age', 'sample_size'))\n", 266 | " for name, stats in name_stats.items():\n", 267 | " if (stats['male'] + stats['female']) > 0:\n", 268 | " writer.writerow((name.lower(), stats['male'] / (stats['male'] + stats['female']), stats['mean_age'], stats['total']))\n" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "with open(os.path.join(dir, 'names.csv'), 'r') as file:\n", 278 | " reader = csv.reader(file)\n", 279 | " for line in reader:\n", 280 | " print(line)\n" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": { 287 | "collapsed": true 288 | }, 289 | "outputs": [], 290 | "source": [] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "referrals", 296 | "language": "python", 297 | "name": "referrals" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.6.2" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 2 314 | } 315 | -------------------------------------------------------------------------------- /release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TAG=v`cat lore/__init__.py | grep '__version__ =' | awk '{ print $3}' | tr -d "'"`  # strip quotes so the tag reads vX.Y.Z, not v'X.Y.Z' 4 | 5 | git tag $TAG 6 | git push origin $TAG 7 | 8 | rm -rf build/ dist/ ./*.egg-info/ 9 | 10 | python setup.py sdist bdist_wheel 11 | 12 | twine upload -r pypi dist/* 13 | 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e .
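# '-e .' installs lore itself in editable mode, so the CLI and test suite run against this working tree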
2 | -------------------------------------------------------------------------------- /runtime.txt: -------------------------------------------------------------------------------- 1 | python-3.6.6 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from io import open 2 | from setuptools import setup 3 | 4 | import sys 5 | sys.lore_no_env = True 6 | 7 | import lore 8 | 9 | 10 | def readme(): 11 | with open('README.rst', 'r', encoding='utf-8') as f: 12 | return f.read() 13 | 14 | 15 | setup( 16 | name='lore', 17 | version=lore.__version__, 18 | description='a framework for building and using data science', 19 | long_description=readme(), 20 | classifiers=[ 21 | lore.__status__, 22 | 'License :: OSI Approved :: MIT License', 23 | 'Programming Language :: Python :: 2', 24 | 'Programming Language :: Python :: 2.7', 25 | 'Programming Language :: Python :: 3', 26 | 'Programming Language :: Python :: 3.5', 27 | 'Programming Language :: Python :: 3.6', 28 | 'Topic :: Scientific/Engineering', 29 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 30 | 'Environment :: Console', 31 | ], 32 | python_requires='>=3.0, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4', 33 | keywords='machine learning framework tensorflow airflow', 34 | url='http://github.com/instacart/lore', 35 | author=lore.__author__, 36 | author_email=lore.__email__, 37 | license='MIT', 38 | packages=[ 39 | 'lore', 40 | 'lore.estimators', 41 | 'lore.estimators.holt_winters', 42 | 'lore.features', 43 | 'lore.io', 44 | 'lore.models', 45 | 'lore.metadata', 46 | 'lore.pipelines', 47 | 'lore.stores', 48 | 'lore.tasks', 49 | 'lore.www', 50 | 'lore.template.init.app', 51 | 'lore.template.init.app.estimators', 52 | 'lore.template.init.app.models', 53 | 'lore.template.init.app.pipelines', 54 | 'lore.template.init.tests.unit', 55 | ], 56 | zip_safe=True, 57 | package_data={ 58 | '': [ 59 | 'data/names.csv', 60 | 'template/init/app/extracts/*', 61 | 'template/init/config/*', 62 | 'template/init/notebooks/*', 63 | 'template/*', 64 | 'template/.*', 65 | 'template/**/*', 66 | 'template/**/.*', 67 | 'template/.**/*', 68 | 'template/.**/.*' 69 | ] 70 | }, 71 | entry_points={ 72 | 'console_scripts': [ 73 | 'lore=lore.__main__:main', 74 | ], 75 | }, 76 | ) 77 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/tests/__init__.py -------------------------------------------------------------------------------- /tests/lore_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import unittest 3 | import lore 4 | 5 | 6 | class TestTruth(unittest.TestCase): 7 | def test_truth(self): 8 | self.assertTrue(True) 9 | -------------------------------------------------------------------------------- /tests/mocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/tests/mocks/__init__.py 
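A quick aside before the test fixtures: the store-backed caching contract from lore/stores/__init__.py above is opt-in per call. A decorated method only consults its store when the caller passes cache=True, and keys combine module, class, method name and a sha1 of the arguments, so distinct arguments memoize separately. A minimal sketch of that contract (the NamesExtract class and its DataFrame are hypothetical, for illustration only):

import pandas

import lore.stores


class NamesExtract(object):
    @lore.stores.cached
    def get_data(self):
        # stand-in for an expensive query or download
        return pandas.DataFrame({'name': ['piper', 'montana']})


extract = NamesExtract()
extract.get_data()            # cache=False by default: recomputed on every call
extract.get_data(cache=True)  # computed once, memoized in the module-level Ram store
extract.get_data(cache=True)  # served from the store under the same sha1-derived key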
-------------------------------------------------------------------------------- /tests/mocks/features.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from lore.features.base import BaseFeatureExporter 3 | from lore.features.s3 import S3FeatureExporter 4 | import pandas as pd 5 | import inflection 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class UserWarehouseSearchesFeature(S3FeatureExporter): 10 | def __init__(self): 11 | super(UserWarehouseSearchesFeature, self).__init__() 12 | 13 | def name(self): 14 | return inflection.underscore(UserWarehouseSearchesFeature.__name__) 15 | 16 | def key(self): 17 | return ['user_id', 'warehouse_id'] 18 | 19 | def serialization(self): 20 | return 'csv' 21 | 22 | def get_data(self): 23 | return pd.DataFrame({'user_id': [1, 1, 2], 'warehouse_id': [1, 2, 1], 'searches': [10, 20, 30], 'conversions': [1, 2, 3]}) 24 | 25 | # write entry in database w/ metadata 26 | # lore.io.s3[mymetadatakey] = self (upload, pickle/json/csv) 27 | 28 | 29 | 30 | ### Ideas 31 | 32 | # feature = UserWarehouseSearchesFeature() 33 | # print(feature.get_data()) 34 | # print(feature) 35 | # feature.publish() 36 | # print(feature.features_as_kv()) 37 | # feature.distribute(Redis(lore.io.redis_conn)) 38 | # building.publish() 39 | # 40 | # downloaded = UserWarehouseSearchesFeature(version=1) 41 | # downloaded._data # filled from s3 previous publish 42 | # 43 | # downloaded.distribute(lore.io.redis) 44 | # 45 | # building.metadata # dataframe 46 | # building.metadata.to_sql('features', lore.io.customers) 47 | # 48 | # lore.io.customers.insert('features', building.metadata) 49 | # lore.io.customers.replace('features', building.metadata) 50 | -------------------------------------------------------------------------------- /tests/mocks/models_keras.py: -------------------------------------------------------------------------------- 1 | import tests 2 | 3 | import lore.models 4 | import lore.models.keras 5 | import lore.estimators.keras 6 | import tests.mocks.pipelines 7 | 8 | 9 | class Keras(lore.models.keras.Base): 10 | def __init__( 11 | self, 12 | embed_size=10 13 | ): 14 | super(Keras, self).__init__( 15 | tests.mocks.pipelines.Xor(), 16 | lore.estimators.keras.Base( 17 | batch_size=1024, 18 | embed_size=embed_size, 19 | hidden_layers=1, 20 | hidden_width=100, 21 | loss='binary_crossentropy', 22 | monitor='val_loss', 23 | cudnn=False 24 | ) 25 | ) 26 | 27 | 28 | class KerasSingle(lore.models.keras.Base): 29 | def __init__( 30 | self, 31 | type 32 | ): 33 | super(KerasSingle, self).__init__( 34 | tests.mocks.pipelines.XorSingle(type=type), 35 | lore.estimators.keras.Base(loss='binary_crossentropy') 36 | ) 37 | 38 | 39 | class NestedKeras(lore.models.keras.Base): 40 | def __init__( 41 | self, 42 | embed_size=10 43 | ): 44 | super(NestedKeras, self).__init__( 45 | tests.mocks.pipelines.MockNestedData(), 46 | lore.estimators.keras.Base( 47 | batch_size=1024, 48 | embed_size=embed_size, 49 | hidden_layers=1, 50 | hidden_width=100, 51 | loss='binary_crossentropy', 52 | monitor='loss', 53 | cudnn=False 54 | ) 55 | ) 56 | 57 | 58 | class KerasMulti(lore.models.keras.Base): 59 | def __init__(self): 60 | super(KerasMulti, self).__init__( 61 | tests.mocks.pipelines.XorMulti(), 62 | lore.estimators.keras.MultiClassifier( 63 | batch_size=1024, 64 | embed_size=10, 65 | hidden_layers=1, 66 | hidden_width=100 67 | ) 68 | ) 69 | 70 | 71 | class BinaryClassifier(lore.models.keras.Base): 72 | def __init__( 73 | self, 74 |
embed_size=10 75 | ): 76 | super(BinaryClassifier, self).__init__( 77 | tests.mocks.pipelines.TwinData(test_size=0.5), 78 | lore.estimators.keras.BinaryClassifier( 79 | batch_size=1024, 80 | embed_size=embed_size, 81 | hidden_layers=1, 82 | hidden_width=100, 83 | cudnn=False 84 | ) 85 | ) 86 | 87 | 88 | class SaimeseTwinsClassifier(lore.models.keras.Base): 89 | def __init__( 90 | self, 91 | embed_size=10, 92 | sequence_embed_size=50, 93 | ): 94 | super(SaimeseTwinsClassifier, self).__init__( 95 | tests.mocks.pipelines.TwinDataWithVaryingEmbedScale(test_size=0.5), 96 | lore.estimators.keras.BinaryClassifier( 97 | batch_size=1024, 98 | embed_size=embed_size, 99 | sequence_embed_size=sequence_embed_size, 100 | sequence_embedding='lstm', 101 | hidden_layers=1, 102 | hidden_width=100, 103 | cudnn=False 104 | ) 105 | ) 106 | 107 | def before_fit(self, *args, **kwargs): 108 | self.called_before_fit = True 109 | 110 | def after_fit(self, *args, **kwargs): 111 | self.called_after_fit = True 112 | 113 | def before_predict(self, *args, **kwargs): 114 | self.called_before_predict = True 115 | 116 | def after_predict(self, *args, **kwargs): 117 | self.called_after_predict = True 118 | 119 | def before_evaluate(self, *args, **kwargs): 120 | self.called_before_evaluate = True 121 | 122 | def after_evaluate(self, *args, **kwargs): 123 | self.called_after_evaluate = True 124 | 125 | def before_score(self, *args, **kwargs): 126 | self.called_before_score = True 127 | 128 | def after_score(self, *args, **kwargs): 129 | self.called_after_score = True 130 | 131 | 132 | -------------------------------------------------------------------------------- /tests/mocks/models_other.py: -------------------------------------------------------------------------------- 1 | from sklearn import svm 2 | import tests.mocks.pipelines  # referenced below; a bare 'import tests' would not load the submodule 3 | 4 | import lore.models.sklearn 5 | import lore.models.xgboost 6 | import lore.models.naive 7 | import lore.estimators.sklearn 8 | import lore.estimators.xgboost 9 | import lore.estimators.naive 10 | 11 | 12 | class XGBoostBinaryClassifier(lore.models.xgboost.Base): 13 | def __init__(self): 14 | super(XGBoostBinaryClassifier, self).__init__( 15 | tests.mocks.pipelines.Xor(), 16 | lore.estimators.xgboost.BinaryClassifier() 17 | ) 18 | 19 | 20 | class XGBoostRegression(lore.models.xgboost.Base): 21 | def __init__(self): 22 | super(XGBoostRegression, self).__init__( 23 | tests.mocks.pipelines.Xor(), 24 | lore.estimators.xgboost.Regression() 25 | ) 26 | 27 | 28 | class XGBoostRegressionWithPredictionLogging(lore.models.xgboost.Base): 29 | def __init__(self): 30 | super(XGBoostRegressionWithPredictionLogging, self).__init__( 31 | tests.mocks.pipelines.Xor(), 32 | lore.estimators.xgboost.Regression() 33 | ) 34 | 35 | 36 | class SVM(lore.models.sklearn.Base): 37 | def __init__(self): 38 | super(SVM, self).__init__( 39 | tests.mocks.pipelines.Xor(), 40 | lore.estimators.sklearn.Base( 41 | svm.SVC() 42 | ) 43 | ) 44 | 45 | def before_fit(self, *args, **kwargs): 46 | self.called_before_fit = True 47 | 48 | def after_fit(self, *args, **kwargs): 49 | self.called_after_fit = True 50 | 51 | def before_predict(self, *args, **kwargs): 52 | self.called_before_predict = True 53 | 54 | def after_predict(self, *args, **kwargs): 55 | self.called_after_predict = True 56 | 57 | def before_evaluate(self, *args, **kwargs): 58 | self.called_before_evaluate = True 59 | 60 | def after_evaluate(self, *args, **kwargs): 61 | self.called_after_evaluate = True 62 | 63 | def before_score(self, *args, **kwargs): 64 | self.called_before_score =
True 65 | 66 | def after_score(self, *args, **kwargs): 67 | self.called_after_score = True 68 | 69 | 70 | class OneHotBinaryClassifier(lore.models.xgboost.Base): 71 | def __init__(self): 72 | super(OneHotBinaryClassifier, self).__init__( 73 | tests.mocks.pipelines.OneHotPipeline(), 74 | lore.estimators.xgboost.BinaryClassifier()) 75 | 76 | def before_fit(self, *args, **kwargs): 77 | self.called_before_fit = True 78 | 79 | def after_fit(self, *args, **kwargs): 80 | self.called_after_fit = True 81 | 82 | def before_predict(self, *args, **kwargs): 83 | self.called_before_predict = True 84 | 85 | def after_predict(self, *args, **kwargs): 86 | self.called_after_predict = True 87 | 88 | def before_evaluate(self, *args, **kwargs): 89 | self.called_before_evaluate = True 90 | 91 | def after_evaluate(self, *args, **kwargs): 92 | self.called_after_evaluate = True 93 | 94 | def before_score(self, *args, **kwargs): 95 | self.called_before_score = True 96 | 97 | def after_score(self, *args, **kwargs): 98 | self.called_after_score = True 99 | 100 | 101 | class NaiveBinaryClassifier(lore.models.naive.Base): 102 | def __init__(self): 103 | super(NaiveBinaryClassifier, self).__init__( 104 | tests.mocks.pipelines.NaivePipeline(), 105 | lore.estimators.naive.BinaryClassifier()) 106 | 107 | def before_fit(self, *args, **kwargs): 108 | self.called_before_fit = True 109 | 110 | def after_fit(self, *args, **kwargs): 111 | self.called_after_fit = True 112 | 113 | def before_predict(self, *args, **kwargs): 114 | self.called_before_predict = True 115 | 116 | def after_predict(self, *args, **kwargs): 117 | self.called_after_predict = True 118 | 119 | def before_evaluate(self, *args, **kwargs): 120 | self.called_before_evaluate = True 121 | 122 | def after_evaluate(self, *args, **kwargs): 123 | self.called_after_evaluate = True 124 | 125 | def before_score(self, *args, **kwargs): 126 | self.called_before_score = True 127 | 128 | def after_score(self, *args, **kwargs): 129 | self.called_after_score = True 130 | -------------------------------------------------------------------------------- /tests/mocks/pipelines.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import pandas 4 | import sqlalchemy 5 | 6 | from lore.encoders import Unique, Pass, Token, Boolean, Enum, Continuous, OneHot, NestedUnique, NestedNorm 7 | from lore.transformers import DateTime 8 | import lore.io 9 | import lore.pipelines.holdout 10 | import lore.pipelines.iterative 11 | import lore.pipelines.time_series 12 | 13 | 14 | class Xor(lore.pipelines.holdout.Base): 15 | def get_data(self): 16 | return pandas.DataFrame({ 17 | 'a': [0, 1, 0, 1] * 1000, 18 | 'b': [0, 0, 1, 1] * 1000, 19 | 'words': ['is false', 'is true', 'is not false', 'is not true'] * 1000, 20 | 'xor': [0, 1, 1, 0] * 1000 21 | }) 22 | 23 | def get_encoders(self): 24 | return ( 25 | Unique('a'), 26 | Unique('b'), 27 | Token('words') 28 | ) 29 | 30 | def get_output_encoder(self): 31 | return Pass('xor') 32 | 33 | 34 | class XorSingle(Xor): 35 | def __init__( 36 | self, 37 | type 38 | ): 39 | super(XorSingle, self).__init__() 40 | self.type = type 41 | 42 | def get_encoders(self): 43 | if self.type == 'tuple': 44 | return ( 45 | Unique('a'), 46 | ) 47 | elif self.type == 'len1': 48 | return ( 49 | Unique('a') 50 | ) 51 | elif self.type == 'single': 52 | return Unique('a') 53 | 54 | 55 | class XorMulti(Xor): 56 | def get_output_encoder(self): 57 | return OneHot('xor') 58 | 59 | 60 | class 
MockData(lore.pipelines.time_series.Base): 61 | def get_data(self): 62 | return pandas.DataFrame({ 63 | 'a': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 64 | 'b': [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], 65 | 'target': [1, 0, 1, 0, 1, 0, 1, 0, 1, 0] 66 | }) 67 | 68 | def get_encoders(self): 69 | return ( 70 | Unique('a'), 71 | Unique('b'), 72 | ) 73 | 74 | def get_output_encoder(self): 75 | return Pass('target') 76 | 77 | 78 | class TwinData(lore.pipelines.time_series.Base): 79 | def get_data(self): 80 | return pandas.DataFrame({ 81 | 'a': [100, 200, 300], 82 | 'a_twin': [300, 500, 100], 83 | 'b': [500, 100, 700], 84 | 'b_twin': [100, 400, 500], 85 | 'c': ["orange", "orange juice", "organic orange juice"], 86 | 'c_twin': ["navel orange", "orange juice", "organic orange juice"], 87 | 'user_id': [1,2,3], 88 | 'price': [1.99, 2.99, 3.99], 89 | 'target': [1, 0, 1] 90 | }) 91 | 92 | def get_encoders(self): 93 | return ( 94 | Unique('a', twin=True), 95 | Unique('b', twin=True), 96 | Token('c', twin=True, sequence_length=3), 97 | Unique('user_id'), 98 | Pass('price') 99 | ) 100 | 101 | def get_output_encoder(self): 102 | return Pass('target') 103 | 104 | 105 | class TwinDataWithVaryingEmbedScale(lore.pipelines.time_series.Base): 106 | def get_data(self): 107 | return pandas.DataFrame({ 108 | 'a': [100, 200, 300], 109 | 'a_twin': [300, 500, 100], 110 | 'b': [500, 100, 700], 111 | 'b_twin': [100, 400, 500], 112 | 'c': ["orange", "orange juice", "organic orange juice"], 113 | 'c_twin': ["navel orange", "orange juice", "organic orange juice"], 114 | 'user_id': [1,2,3], 115 | 'price': [1.99, 2.99, 3.99], 116 | 'target': [1, 0, 1] 117 | }) 118 | 119 | def get_encoders(self): 120 | return ( 121 | Unique('a', embed_scale=3, twin=True), 122 | Unique('b', embed_scale=4, twin=True), 123 | Token('c', embed_scale=5, twin=True, sequence_length=3), 124 | Unique('user_id'), 125 | Pass('price') 126 | ) 127 | 128 | def get_output_encoder(self): 129 | return Pass('target') 130 | 131 | class Users(lore.pipelines.iterative.Base): 132 | dataframe = pandas.DataFrame({ 133 | 'id': range(1000), 134 | 'first_name': [str(i) for i in range(1000)], 135 | 'last_name': [str(i % 100) for i in range(1000)], 136 | 'subscriber': [i % 2 == 0 for i in range(1000)], 137 | 'signup_at': [datetime.datetime.now()] * 1000 138 | }) 139 | sqlalchemy_table = sqlalchemy.Table( 140 | 'tests_low_memory_users', lore.io.main.metadata, 141 | sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True), 142 | sqlalchemy.Column('first_name', sqlalchemy.String(50)), 143 | sqlalchemy.Column('last_name', sqlalchemy.String(50)), 144 | sqlalchemy.Column('subscriber', sqlalchemy.Boolean()), 145 | sqlalchemy.Column('signup_at', sqlalchemy.DateTime()), 146 | ) 147 | sqlalchemy_table.drop(checkfirst=True) 148 | lore.io.main.metadata.create_all() 149 | lore.io.main.insert('tests_low_memory_users', dataframe) 150 | 151 | def _split_data(self): 152 | self.connection.execute('drop table if exists {name};'.format(name=self.table)) 153 | super(Users, self)._split_data() 154 | 155 | def get_data(self): 156 | return lore.io.main.dataframe(sql='select * from tests_low_memory_users', chunksize=2) 157 | 158 | def get_encoders(self): 159 | return ( 160 | Unique('id'), 161 | Unique('first_name'), 162 | Unique('last_name'), 163 | Boolean('subscriber'), 164 | Enum(DateTime('signup_at', 'dayofweek')), 165 | ) 166 | 167 | def get_output_encoder(self): 168 | return Pass('subscriber') 169 | 170 | 171 | class MockNestedData(lore.pipelines.time_series.Base): 172 | def get_data(self): 173 
| return pandas.DataFrame({ 174 | 'a': [['a', 'b'], ['a', 'b', 'c'], ['c', 'd'], ['a', 'e'], None], 175 | 'b': [[0, 1, 2], None, [2, 3, 4, 5], [1], [-1, 10]], 176 | 'target': [1, 0, 1, 0, 1] 177 | }) 178 | 179 | def get_encoders(self): 180 | return ( 181 | NestedUnique('a'), 182 | NestedNorm('b'), 183 | ) 184 | 185 | def get_output_encoder(self): 186 | return Pass('target') 187 | 188 | 189 | class OneHotPipeline(lore.pipelines.holdout.Base): 190 | def get_data(self): 191 | return pandas.DataFrame({ 192 | 'a': [1, 1, 2, 3] * 1000, 193 | 'b': [0, 0, 1, 1] * 1000, 194 | 'words': ['is false', 'is true', 'is not false', 'is not true'] * 1000, 195 | 'xor': [0, 1, 1, 0] * 1000 196 | }) 197 | 198 | def get_encoders(self): 199 | return ( 200 | OneHot('a', minimum_occurrences=1001, compressed=True), 201 | OneHot('b'), 202 | Token('words') 203 | ) 204 | 205 | def get_output_encoder(self): 206 | return Pass('xor') 207 | 208 | 209 | class NaivePipeline(lore.pipelines.holdout.Base): 210 | def get_data(self): 211 | return pandas.DataFrame({ 212 | 'x': [1, 2, 3]*1000, 213 | 'y': [0, 1, 1]*1000 214 | }) 215 | 216 | def get_encoders(self): 217 | return(Pass('x')) 218 | 219 | def get_output_encoder(self): 220 | return Pass('y') 221 | -------------------------------------------------------------------------------- /tests/mocks/tasks.py: -------------------------------------------------------------------------------- 1 | import lore.tasks.base 2 | 3 | class EchoTask(lore.tasks.base.Base): 4 | def main(self, arg1=None): 5 | pass 6 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from lore import env 4 | import lore 5 | lore.env.require(lore.dependencies.TEST) 6 | -------------------------------------------------------------------------------- /tests/unit/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instacart/lore/a14f65a96d0ea2513a35e424b4e16d948115b89c/tests/unit/io/__init__.py -------------------------------------------------------------------------------- /tests/unit/io/test_connection.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import unicode_literals 3 | 4 | import unittest 5 | from threading import Thread 6 | import datetime 7 | import time 8 | import math 9 | 10 | import lore.io.connection 11 | import sqlalchemy 12 | from sqlalchemy import event 13 | from sqlalchemy.engine import Engine 14 | import pandas 15 | 16 | import lore 17 | 18 | 19 | calls = 0 20 | @event.listens_for(Engine, "after_cursor_execute") 21 | def count_sql_calls(conn, cursor, statement, parameters, context, executemany): 22 | global calls 23 | calls += 1 24 | 25 | 26 | class TestConnection(unittest.TestCase): 27 | @classmethod 28 | def setUpClass(self): 29 | self.table = sqlalchemy.Table( 30 | 'tests_users', lore.io.main.metadata, 31 | sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True), 32 | sqlalchemy.Column('first_name', sqlalchemy.String(50), nullable=False), 33 | sqlalchemy.Column('last_name', sqlalchemy.String(50), nullable=False), 34 | sqlalchemy.Column('nullable', sqlalchemy.Float(), nullable=True), 35 | sqlalchemy.Column('zero', sqlalchemy.Integer(), nullable=False), 36 | sqlalchemy.Column('pi', sqlalchemy.Float(), nullable=False), 37 | sqlalchemy.Index('index_tests_users_first_name_last_name', 
'first_name', 'last_name', unique=True), 38 | sqlalchemy.Index('long_name_long_name_long_name_long_name_long_name_long_name_63_', 'first_name', unique=True), 39 | ) 40 | lore.io.main.metadata.create_all() 41 | 42 | def setUp(self): 43 | self.table.drop() 44 | self.dataframe = pandas.DataFrame({ 45 | 'id': range(1001), 46 | 'first_name': [str(i) for i in range(1001)], 47 | 'last_name': [str(i) for i in range(1001)], 48 | 'nullable': [i if i % 2 == 0 else None for i in range(1001)], 49 | 'zero': [0] * 1001, 50 | 'pi': [math.pi] * 1001, 51 | }) 52 | 53 | lore.io.main.metadata.create_all() 54 | 55 | def test_connection(self): 56 | self.assertTrue(hasattr(lore.io, 'main')) 57 | 58 | def test_replace(self): 59 | lore.io.main.replace(self.table.name, self.dataframe) 60 | selected = lore.io.main.dataframe(sql='select * from tests_users') 61 | self.assertEqual(self.dataframe['first_name'].tolist(), selected['first_name'].tolist()) 62 | self.assertEqual(self.dataframe['last_name'].tolist(), selected['last_name'].tolist()) 63 | 64 | def test_insert_bulk(self): 65 | global calls 66 | calls = 0 67 | lore.io.main.insert(self.table.name, self.dataframe) 68 | self.assertEqual(calls, 0) 69 | 70 | def test_insert_batches(self): 71 | global calls 72 | calls = 0 73 | lore.io.main.insert(self.table.name, self.dataframe, batch_size=100) 74 | self.assertEqual(calls, 0) 75 | 76 | def test_multiple_statements(self): 77 | users = lore.io.main.select(sql=''' 78 | insert into tests_users(first_name, last_name, zero, pi) 79 | values ('1', '2', 0, 3.14); 80 | insert into tests_users(first_name, last_name, zero, pi) 81 | values ('3', '4', 0, 3.14); 82 | insert into tests_users(first_name, last_name, zero, pi) 83 | values ('4', '5', 0, 3.14); 84 | select * from tests_users; 85 | ''') 86 | self.assertEqual(len(users), 3) 87 | 88 | def test_persistent_temp_tables(self): 89 | lore.io.main.execute(sql='create temporary table tests_persistence(id integer not null primary key)') 90 | lore.io.main.execute(sql='insert into tests_persistence values (1), (2), (3)') 91 | temps = lore.io.main.select(sql='select count(*) from tests_persistence') 92 | self.assertEqual(temps[0][0], 3) 93 | 94 | def test_connection_temp_table_isolation(self): 95 | lore.io.main.execute(sql='create temporary table tests_temp(id integer not null primary key)') 96 | lore.io.main.execute(sql='insert into tests_temp values (1), (2), (3)') 97 | lore.io.main_two.execute(sql='create temporary table tests_temp(id integer not null primary key)') 98 | lore.io.main_two.execute(sql='insert into tests_temp values (1), (2), (3)') 99 | 100 | temps = lore.io.main.select(sql='select count(*) from tests_temp') 101 | temps_two = lore.io.main_two.select(sql='select count(*) from tests_temp') 102 | self.assertEqual(temps[0][0], 3) 103 | self.assertEqual(temps_two[0][0], 3) 104 | 105 | def test_connection_autocommit_isolation(self): 106 | lore.io.main.execute(sql='drop table if exists tests_autocommit') 107 | lore.io.main.execute(sql='create table tests_autocommit(id integer not null primary key)') 108 | lore.io.main.execute(sql='insert into tests_autocommit values (1), (2), (3)') 109 | temps_two = lore.io.main_two.select(sql='select count(*) from tests_autocommit') 110 | self.assertEqual(temps_two[0][0], 3) 111 | 112 | def test_transaction_rollback(self): 113 | lore.io.main.execute(sql='drop table if exists tests_autocommit') 114 | lore.io.main.execute(sql='create table tests_autocommit(id integer not null primary key)') 115 | 116 | lore.io.main.execute(sql='insert into 
tests_autocommit values (0)') 117 | with self.assertRaises(sqlalchemy.exc.IntegrityError): 118 | with lore.io.main as transaction: 119 | transaction.execute(sql='insert into tests_autocommit values (1), (2), (3)') 120 | transaction.execute(sql='insert into tests_autocommit values (1), (2), (3)') 121 | 122 | inserted = lore.io.main_two.select(sql='select count(*) from tests_autocommit')[0][0] 123 | 124 | self.assertEqual(inserted, 1) 125 | 126 | def test_out_of_transaction_does_not_block_concurrent_writes(self): 127 | lore.io.main.execute(sql='drop table if exists tests_autocommit') 128 | lore.io.main.execute(sql='create table tests_autocommit(id integer not null primary key)') 129 | 130 | priors = [] 131 | posts = [] 132 | thrown = [] 133 | 134 | def insert(delay=0): 135 | try: 136 | priors.append(lore.io.main.select(sql='select count(*) from tests_autocommit')[0][0]) 137 | lore.io.main.execute(sql='insert into tests_autocommit values (1), (2), (3)') 138 | posts.append(lore.io.main.select(sql='select count(*) from tests_autocommit')[0][0]) 139 | time.sleep(delay) 140 | except sqlalchemy.exc.IntegrityError as ex: 141 | thrown.append(True) 142 | 143 | slow = Thread(target=insert, args=(1,)) 144 | fast = Thread(target=insert, args=(0,)) 145 | 146 | slow.start() 147 | time.sleep(0.5) 148 | fast.start() 149 | 150 | fast.join() 151 | fast_done = datetime.datetime.now() 152 | slow.join() 153 | slow_done = datetime.datetime.now() 154 | 155 | self.assertEqual(priors, [0, 3]) 156 | self.assertEqual(posts, [3]) 157 | self.assertEqual(thrown, [True]) 158 | self.assertAlmostEqual((slow_done - fast_done).total_seconds(), 0.5, 1) 159 | 160 | def test_close(self): 161 | lore.io.main.execute(sql='create temporary table tests_close(id integer not null primary key)') 162 | lore.io.main.execute(sql='insert into tests_close values (1), (2), (3)') 163 | lore.io.main.close() 164 | reopened = lore.io.main.select(sql='select 1') 165 | self.assertEquals(reopened, [(1,)]) 166 | with self.assertRaises(sqlalchemy.exc.ProgrammingError): 167 | lore.io.main.select(sql='select count(*) from tests_close') 168 | 169 | def test_reconnect_and_retry(self): 170 | original_execute = lore.io.main._connection.execute 171 | 172 | def raise_dbapi_error_on_first_call(sql, bindings): 173 | lore.io.main._connection.execute = original_execute 174 | e = lore.io.connection.Psycopg2OperationalError('server closed the connection unexpectedly. 
This probably means the server terminated abnormally before or while processing the request.') 175 | raise sqlalchemy.exc.DBAPIError('select 1', [], e, True) 176 | 177 | exceptions = lore.env.STDOUT_EXCEPTIONS 178 | lore.env.STDOUT_EXCEPTIONS = False 179 | connection = lore.io.main._connection 180 | lore.io.main._connection.execute = raise_dbapi_error_on_first_call 181 | 182 | result = lore.io.main.select(sql='select 1') 183 | lore.env.STDOUT_EXCEPTIONS = exceptions 184 | 185 | self.assertNotEquals(connection, lore.io.main._connection) 186 | self.assertEquals(result[0][0], 1) 187 | 188 | def test_tuple_interpolation(self): 189 | lore.io.main.execute(sql='create temporary table tests_interpolation(id integer not null primary key)') 190 | lore.io.main.execute(sql='insert into tests_interpolation values (1), (2), (3)') 191 | temps = lore.io.main.select(sql='select * from tests_interpolation where id in {ids}', ids=(1, 2, 3)) 192 | self.assertEqual(len(temps), 3) 193 | 194 | def test_reconnect_and_retry_on_expired_connection(self): 195 | original_execute = lore.io.main._connection.execute 196 | 197 | def raise_snowflake_programming_error_on_first_call(sql, bindings): 198 | lore.io.main._connection.execute = original_execute 199 | e = lore.io.connection.SnowflakeProgrammingError('Authentication token has expired. The user must authenticate again') 200 | raise sqlalchemy.exc.DBAPIError('select 1', [], e, True) 201 | 202 | exceptions = lore.env.STDOUT_EXCEPTIONS 203 | lore.env.STDOUT_EXCEPTIONS = False 204 | connection = lore.io.main._connection 205 | lore.io.main._connection.execute = raise_snowflake_programming_error_on_first_call 206 | 207 | result = lore.io.main.select(sql='select 1') 208 | lore.env.STDOUT_EXCEPTIONS = exceptions 209 | 210 | self.assertNotEquals(connection, lore.io.main._connection) 211 | self.assertEquals(result[0][0], 1) 212 | 213 | -------------------------------------------------------------------------------- /tests/unit/io/test_io.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import unicode_literals 3 | 4 | import unittest 5 | 6 | import lore.io 7 | 8 | import os 9 | 10 | 11 | class TestRemoteFromLocalPaths(unittest.TestCase): 12 | def test_works_dir_is_removed(self): 13 | local = os.path.join(lore.env.WORK_DIR, 'README.rst') 14 | self.assertIsNotNone(lore.env.WORK_DIR) 15 | self.assertEqual(lore.io.remote_from_local(local), '/README.rst') 16 | 17 | def test_relative_is_ok(self): 18 | local = 'README.rst' 19 | self.assertEqual(lore.io.remote_from_local(local), 'README.rst') 20 | 21 | def test_absolute_is_ok(self): 22 | local = '/README.rst' 23 | self.assertEqual(lore.io.remote_from_local(local), '/README.rst') 24 | 25 | 26 | class TestPrefixRemoteRoot(unittest.TestCase): 27 | def test_absolute(self): 28 | path = '/README.rst' 29 | self.assertEqual(lore.io.prefix_remote_root(path), 'test/README.rst') 30 | 31 | def test_relative(self): 32 | path = 'README.rst' 33 | self.assertEqual(lore.io.prefix_remote_root(path), 'test/README.rst') 34 | 35 | def test_is_absolute_pre_prefixed_safe(self): 36 | path = '/test/README.rst' 37 | self.assertEqual(lore.io.prefix_remote_root(path), 'test/README.rst') 38 | 39 | def test_is_relative_pre_prefixed_safe(self): 40 | path = 'test/README.rst' 41 | self.assertEqual(lore.io.prefix_remote_root(path), 'test/README.rst') 42 | 43 | def test_is_not_short_sighted(self): 44 | path = 'test_not_env/README.rst' 45 | self.assertEqual(lore.io.prefix_remote_root(path), 
--------------------------------------------------------------------------------
/tests/unit/test_env.py:
--------------------------------------------------------------------------------
# coding=utf-8
import unittest
import lore
import lore.env


class TestEnv(unittest.TestCase):
    def test_initialization(self):
        self.assertEqual(lore.env.NAME, lore.env.TEST)

    def test_canonicalize_package_name(self):
        self.assertEqual(lore.env.normalize_package_name("zope.interface"), lore.env.normalize_package_name("zope-interface"))
--------------------------------------------------------------------------------
/tests/unit/test_estimators.py:
--------------------------------------------------------------------------------
import unittest


class TestBase(unittest.TestCase):
    pass
--------------------------------------------------------------------------------
/tests/unit/test_features.py:
--------------------------------------------------------------------------------
import unittest

# from moto import mock_s3
import boto3
import pandas

from lore.io import download

from tests.mocks.features import UserWarehouseSearchesFeature


class TestFeatures(unittest.TestCase):

    # @mock_s3
    def xtest_s3_features(self):
        s3 = boto3.resource('s3')
        # We need to create the bucket since this is all in Moto's 'virtual' AWS account
        s3.create_bucket(Bucket='lore-test')

        user_warehouse_feature = UserWarehouseSearchesFeature()
        user_warehouse_feature.publish()

        # temp_path = download(user_warehouse_feature.data_path(), cache=False)

        # fetched_data = pandas.read_csv(temp_path)
        # self.assertTrue(len(user_warehouse_feature.get_data()) == 3)
        # self.assertTrue(user_warehouse_feature.get_data().equals(fetched_data))
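xtest_s3_features is deliberately disabled: the leading x keeps unittest from collecting it, and the moto decorator is commented out. If it were revived, the shape the comments already sketch would presumably look like the following, where moto's mock_s3 decorator (moto < 5.0) intercepts boto3 so the bucket and object live only in memory:

# Hypothetical re-enabled version, assuming moto (< 5.0) is installed;
# everything below runs against moto's in-memory AWS, never real S3.
import boto3
from moto import mock_s3

@mock_s3
def roundtrip():
    s3 = boto3.resource('s3', region_name='us-east-1')
    s3.create_bucket(Bucket='lore-test')                        # virtual bucket
    s3.Object('lore-test', 'features/demo.csv').put(Body=b'a,b\n1,2\n')
    body = s3.Object('lore-test', 'features/demo.csv').get()['Body'].read()
    assert body == b'a,b\n1,2\n'

roundtrip()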
--------------------------------------------------------------------------------
/tests/unit/test_main.py:
--------------------------------------------------------------------------------
import unittest
import sys
from lore.util import suppress_stdout


class TestImport(unittest.TestCase):
    def test_import(self):
        # make sure this at least gets loaded
        import lore.__main__
        self.assertTrue(True)


class TestTask(unittest.TestCase):
    def test_task(self):
        import lore.__main__

        args = ('task', 'tests.mocks.tasks.EchoTask', '--arg1', 'true')
        with suppress_stdout():
            if sys.version_info[0] == 2:
                # assertLogs is Python 3 only; on 2 just prove the call succeeds
                lore.__main__.main(args)
                self.assertTrue(True)
            else:
                with self.assertLogs('lore.__main__') as log:
                    lore.__main__.main(args)
                self.assertEqual(log.output, [
                    "INFO:lore.__main__:starting task: " +
                    "tests.mocks.tasks.EchoTask {'arg1': 'true'}"
                ])
--------------------------------------------------------------------------------
/tests/unit/test_metadata.py:
--------------------------------------------------------------------------------
import unittest
import tests.mocks.models_other
import lore.io
import lore.metadata
import datetime


def truncate_table(table_name):
    if lore.io.metadata.adapter == 'postgres':
        lore.io.metadata.execute('TRUNCATE {} RESTART IDENTITY CASCADE'.format(table_name))
    else:
        lore.io.metadata.execute('DELETE FROM {}'.format(table_name))


def truncate_metadata_tables():
    truncate_table('fittings')
    truncate_table('commits')
    truncate_table('predictions')
    truncate_table('snapshots')


def setUpModule():
    truncate_metadata_tables()


class TestCrud(unittest.TestCase):
    def test_lifecycle(self):
        commit = lore.metadata.Commit.create(sha='abc')
        self.assertEqual(commit.__class__, lore.metadata.Commit)
        self.assertIsNotNone(commit.sha)

        commit.created_at = datetime.datetime.now()
        commit.save()

        commits = lore.metadata.Commit.all()
        self.assertEqual(len(commits), 1)

        first = lore.metadata.Commit.first()
        self.assertEqual(first.sha, commit.sha)

        last = lore.metadata.Commit.last()
        self.assertEqual(first.sha, last.sha)

        commit.delete()

        first = lore.metadata.Commit.first()
        self.assertIsNone(first)


class TestFitting(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = tests.mocks.models_other.XGBoostRegressionWithPredictionLogging()
        cls.df = cls.model.pipeline.training_data

    def test_model_fit(self):
        self.model.fit()
        fitting = lore.metadata.Fitting.last()
        self.assertEqual(fitting.id, self.model.fitting.id)
        self.assertEqual(fitting.model, self.model.name)
        self.assertEqual(fitting.id, self.model.last_fitting().id)


class TestPredictionLogging(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = tests.mocks.models_other.XGBoostRegressionWithPredictionLogging()
        cls.df = cls.model.pipeline.training_data

    def test_prediction_logging(self):
        self.model.fit()
        self.model.predict(self.df, log_predictions=True, key_cols=['a', 'b'])
        prediction_metadata = lore.metadata.Prediction.first(fitting_id=self.model.fitting.id)
        self.assertIsNotNone(prediction_metadata)
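truncate_table branches on the adapter because TRUNCATE ... RESTART IDENTITY CASCADE is Postgres syntax; SQLite, a likely non-Postgres metadata store here, has no TRUNCATE statement at all, which is why the portable DELETE FROM fallback exists. A quick stdlib demonstration:

# TRUNCATE is not part of SQLite's SQL dialect, so the adapter check above
# falls back to the portable DELETE FROM.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE fittings (id INTEGER PRIMARY KEY)')
conn.execute('DELETE FROM fittings')         # fine on every backend
try:
    conn.execute('TRUNCATE TABLE fittings')  # rejected by SQLite
except sqlite3.OperationalError as error:
    print(error)                             # near "TRUNCATE": syntax error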
--------------------------------------------------------------------------------
/tests/unit/test_models_keras.py:
--------------------------------------------------------------------------------
import unittest

import tests.mocks.models_keras
import scipy.stats
import numpy


class TestKeras(unittest.TestCase):
    def test_lifecycle(self):
        model = tests.mocks.models_keras.Keras()
        model.fit(epochs=1)
        model.save()

        loaded = tests.mocks.models_keras.Keras.load()
        self.assertEqual(loaded.fitting, model.fitting)

    def test_before_after_hooks(self):
        model = tests.mocks.models_keras.SaimeseTwinsClassifier()
        model.fit(epochs=1, test=True, score=True)
        model.predict(model.pipeline.test_data)
        self.assertTrue(model.called_before_fit)
        self.assertTrue(model.called_after_fit)
        self.assertTrue(model.called_before_predict)
        self.assertTrue(model.called_after_predict)
        self.assertTrue(model.called_before_evaluate)
        self.assertTrue(model.called_after_evaluate)
        self.assertTrue(model.called_before_score)
        self.assertTrue(model.called_after_score)

    def test_hyper_param_search(self):
        model = tests.mocks.models_keras.Keras()
        result = model.hyper_parameter_search(
            {'embed_size': scipy.stats.randint(low=1, high=10)},
            n_iter=2,
            fit_params={'epochs': 2}
        )
        self.assertEqual(model.estimator, result.best_estimator_)

    def test_lstm_embeddings(self):
        model = tests.mocks.models_keras.Keras()
        model.estimator.sequence_embedding = 'lstm'
        model.fit(epochs=1)
        assert True

    def test_gru_embeddings(self):
        model = tests.mocks.models_keras.Keras()
        model.estimator.sequence_embedding = 'gru'
        model.fit(epochs=1)
        assert True

    def test_rnn_embeddings(self):
        model = tests.mocks.models_keras.Keras()
        model.estimator.sequence_embedding = 'simple_rnn'
        model.fit(epochs=1)
        assert True

    def test_flat_embeddings(self):
        model = tests.mocks.models_keras.Keras()
        model.estimator.sequence_embedding = 'flatten'
        model.fit(epochs=1)
        assert True

    def test_towers(self):
        model = tests.mocks.models_keras.Keras()
        model.estimator.towers = 2
        model.fit(epochs=1)
        assert True

    def test_short_names(self):
        model = tests.mocks.models_keras.Keras()
        model.estimator.short_names = True
        model.estimator.build()
        assert True

    def test_batch_norm(self):
        model = tests.mocks.models_keras.Keras()
        model.estimator.batch_norm = True
        model.estimator.build()
        assert True

    def test_kernel_initializer(self):
        model = tests.mocks.models_keras.Keras()
        model.estimator.kernel_initializer = 'he_uniform'
        model.estimator.build()
        assert True


class TestKerasSingle(unittest.TestCase):
    def test_single_encoder_a(self):
        model = tests.mocks.models_keras.KerasSingle(type='tuple')
        model.estimator.build()

    def test_single_encoder_b(self):
        model = tests.mocks.models_keras.KerasSingle(type='len1')
        model.estimator.build()

    def test_single_encoder_c(self):
        model = tests.mocks.models_keras.KerasSingle(type='single')
        model.estimator.build()


class TestNestedKeras(unittest.TestCase):
    def test_lifecycle(self):
        model = tests.mocks.models_keras.NestedKeras()
        model.fit(epochs=1)
        model.save()

        loaded = tests.mocks.models_keras.NestedKeras.load()
        self.assertEqual(loaded.fitting, model.fitting)


class TestKerasMulti(unittest.TestCase):
    def test_multi(self):
        model = tests.mocks.models_keras.KerasMulti()
        model.fit(epochs=1)
        assert True


class TestSiameseArchitectureBinaryClassifier(unittest.TestCase):

    def test_siamese_architecture_twin_sequence_pair_shapes(self):
        model = tests.mocks.models_keras.SaimeseTwinsClassifier()
        model.fit()
        model.save()

        keras_model = model.estimator.keras
        twin_layers = [l.name for l in keras_model.layers if "twin" in l.name]

        for twin_layer_name in twin_layers:
            original_layer_name = twin_layer_name.replace("_twin", "")
            siamese_original_layer = keras_model.get_layer(original_layer_name)
            siamese_twin_layer = keras_model.get_layer(twin_layer_name)
            self.assertEqual(siamese_twin_layer.input_shape, siamese_original_layer.input_shape)
            self.assertEqual(siamese_twin_layer.output_shape, siamese_original_layer.output_shape)


class TestBinaryClassifier(unittest.TestCase):
    def test_lifecycle(self):
        model = tests.mocks.models_keras.BinaryClassifier()
        model.estimator.sequence_embedding = 'lstm'
        model.fit()
        model.save()
        # loaded = tests.mocks.models_other.BinaryClassifier.load()
        # self.assertEqual(loaded.fitting, model.fitting)

    def test_rnn_embeddings(self):
        model = tests.mocks.models_keras.BinaryClassifier()
        model.estimator.sequence_embedding = 'simple_rnn'
        model.fit(epochs=1)
        assert True

    def test_flatten_embeddings(self):
        model = tests.mocks.models_keras.BinaryClassifier()
        model.estimator.sequence_embedding = 'flatten'
        model.fit(epochs=1)
        assert True
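Two patterns in this file are worth unpacking. First, test_hyper_param_search passes scipy distributions, n_iter and fit_params, and reads best_estimator_ off the result, which is the same surface as scikit-learn's RandomizedSearchCV and presumably what lore wraps here. A plain-sklearn analogue for comparison; the Ridge estimator and alpha parameter are illustrative, not lore internals:

# Plain scikit-learn analogue of the search surface used above.
import numpy
import scipy.stats
from sklearn.linear_model import Ridge
from sklearn.model_selection import RandomizedSearchCV

X = numpy.random.rand(100, 3)
y = X @ numpy.array([1.0, 2.0, 3.0])

search = RandomizedSearchCV(Ridge(), {'alpha': scipy.stats.uniform(0.1, 10)}, n_iter=2, cv=3)
search.fit(X, y)
print(search.best_estimator_)  # the refit winner, like result.best_estimator_ above

Second, the Siamese test checks that every layer named *_twin matches the shapes of its original. The usual Keras idiom behind such twins is reusing one layer instance on two inputs, which shares weights by construction; a minimal standalone sketch, assuming the standalone keras package of this repo's era and illustrative layer names:

# Minimal Keras weight-sharing sketch: one Dense instance applied to two
# inputs yields twin branches with identical shapes and shared weights.
import keras

left = keras.layers.Input(shape=(4,), name='left')
right = keras.layers.Input(shape=(4,), name='right')
encoder = keras.layers.Dense(8, name='encoder')  # single shared instance

encoded_left = encoder(left)
encoded_right = encoder(right)                   # same weights as the left branch

model = keras.models.Model(inputs=[left, right], outputs=[encoded_left, encoded_right])
assert model.get_layer('encoder') is encoder     # one layer, two call sites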
--------------------------------------------------------------------------------
/tests/unit/test_models_other.py:
--------------------------------------------------------------------------------
import unittest
import tests.mocks.models_other

import numpy


class TestXGBoostRegression(unittest.TestCase):
    def test_lifecycle(self):
        model = tests.mocks.models_other.XGBoostRegression()
        model.fit()
        model.save()

        loaded = tests.mocks.models_other.XGBoostRegression.load()
        self.assertEqual(loaded.fitting, model.fitting)


class TestXGBoostBinaryClassifier(unittest.TestCase):
    def test_lifecycle(self):
        model = tests.mocks.models_other.XGBoostBinaryClassifier()
        model.fit()
        model.save()

        loaded = tests.mocks.models_other.XGBoostBinaryClassifier.load()
        self.assertEqual(loaded.fitting, model.fitting)

    def test_probs(self):
        model = tests.mocks.models_other.XGBoostBinaryClassifier()
        model.fit()
        model.predict_proba(model.pipeline.test_data)
        assert True


class TestSKLearn(unittest.TestCase):
    def test_lifecycle(self):
        model = tests.mocks.models_other.SVM()
        model.fit()
        model.save()

        loaded = tests.mocks.models_other.SVM.load()
        self.assertEqual(loaded.fitting, model.fitting)

    def test_before_after_hooks(self):
        model = tests.mocks.models_other.SVM()
        model.fit(test=True, score=True)
        model.predict(model.pipeline.test_data)
        self.assertTrue(model.called_before_fit)
        self.assertTrue(model.called_after_fit)
        self.assertTrue(model.called_before_predict)
        self.assertTrue(model.called_after_predict)
        self.assertTrue(model.called_before_evaluate)
        self.assertTrue(model.called_after_evaluate)
        self.assertTrue(model.called_before_score)
        self.assertTrue(model.called_after_score)


class TestOneHotBinaryClassifier(unittest.TestCase):
    def test_lifecycle(self):
        model = tests.mocks.models_other.OneHotBinaryClassifier()
        model.fit()
        model.save()

        loaded = tests.mocks.models_other.OneHotBinaryClassifier.load()
        self.assertEqual(loaded.fitting, model.fitting)

    def test_before_after_hooks(self):
        model = tests.mocks.models_other.OneHotBinaryClassifier()
        model.fit(test=True, score=True)
        model.predict(model.pipeline.test_data)

        self.assertTrue(model.called_before_fit)
        self.assertTrue(model.called_after_fit)
        self.assertTrue(model.called_before_predict)
        self.assertTrue(model.called_after_predict)
        self.assertTrue(model.called_before_evaluate)
        self.assertTrue(model.called_after_evaluate)
        self.assertTrue(model.called_before_score)
        self.assertTrue(model.called_after_score)


class TestNaiveBinaryClassifier(unittest.TestCase):
    def test_lifecycle(self):
        model = tests.mocks.models_other.NaiveBinaryClassifier()
        model.fit()
        model.save()

        loaded = tests.mocks.models_other.NaiveBinaryClassifier.load()
        self.assertEqual(loaded.fitting, model.fitting)

    def test_before_after_hooks(self):
        model = tests.mocks.models_other.NaiveBinaryClassifier()
        model.fit(test=True, score=True)
        model.predict(model.pipeline.test_data)

        self.assertTrue(model.called_before_fit)
        self.assertTrue(model.called_after_fit)
        self.assertTrue(model.called_before_predict)
        self.assertTrue(model.called_after_predict)
        self.assertTrue(model.called_before_evaluate)
        self.assertTrue(model.called_after_evaluate)
        self.assertTrue(model.called_before_score)
        self.assertTrue(model.called_after_score)

    def test_preds(self):
        model = tests.mocks.models_other.NaiveBinaryClassifier()
        model.fit(test=True, score=True)
        preds = model.predict(model.pipeline.test_data)
        self.assertTrue((preds == 1).all())

    def test_probs(self):
        model = tests.mocks.models_other.NaiveBinaryClassifier()
        model.fit(test=True, score=True)
        probs = model.predict_proba(model.pipeline.test_data)[:, 1]
        self.assertTrue((numpy.abs(probs - 0.667) < 0.001).all())
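The NaiveBinaryClassifier assertions encode majority-class baseline behaviour: predictions are all 1, and the positive-class probability (column 1 of predict_proba, the usual sklearn layout) is about 0.667, presumably the positive base rate of the mock pipeline's labels. The arithmetic, with an assumed 2:1 label mix:

# Base-rate arithmetic behind the 0.667 assertion; the 2:1 positive/negative
# mix is an assumption about the mock pipeline, used only for illustration.
import numpy

y = numpy.array([1, 1, 0] * 100)  # hypothetical two-thirds-positive labels
p_positive = y.mean()             # 0.666..., the probability a naive model
print(round(p_positive, 3))       # assigns to every row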
--------------------------------------------------------------------------------
/tests/unit/test_pipelines.py:
--------------------------------------------------------------------------------
from __future__ import generators

import unittest

import pandas

import tests.mocks.pipelines


class TestTimedTrainTestSplit(unittest.TestCase):

    def test_time_series(self):
        mock = tests.mocks.pipelines.MockData(sort_by=None)
        self.assertEqual(mock.training_data['a'].max(), 8)
        self.assertEqual(mock.validation_data['a'].max(), 9)
        self.assertEqual(mock.test_data['a'].max(), 10)

        mock = tests.mocks.pipelines.MockData(sort_by='a')
        self.assertEqual(mock.training_data['a'].max(), 8)
        self.assertEqual(mock.validation_data['a'].max(), 9)
        self.assertEqual(mock.test_data['a'].max(), 10)

        mock = tests.mocks.pipelines.MockData(sort_by='b')
        self.assertEqual(mock.training_data['b'].max(), 28)
        self.assertEqual(mock.validation_data['b'].max(), 29)
        self.assertEqual(mock.test_data['b'].max(), 30)


class TestLowMemory(unittest.TestCase):

    def setUp(self):
        self.dataframe = tests.mocks.pipelines.Users.dataframe
        self.pipeline = tests.mocks.pipelines.Users()

    def test_has_a_name(self):
        self.assertIsNotNone(self.pipeline.name)

    def test_columns(self):
        self.assertEqual(set(self.pipeline.columns), set(self.dataframe.columns))

    def test_split(self):
        self.pipeline._split_data()
        self.assertEqual(self.pipeline.table_length(self.pipeline.table_training), 800)
        self.assertEqual(self.pipeline.table_length(self.pipeline.table_validation), 100)
        self.assertEqual(self.pipeline.table_length(self.pipeline.table_test), 100)

    def test_split_stratify(self):
        self.pipeline = tests.mocks.pipelines.Users()
        self.pipeline.stratify = 'last_name'

        data = pandas.concat([chunk for chunk in self.pipeline.training_data])
        self.assertEqual(len(data['last_name'].drop_duplicates()), 80)

        data = pandas.concat([chunk for chunk in self.pipeline.validation_data])
        self.assertEqual(len(data['last_name'].drop_duplicates()), 10)

        data = pandas.concat([chunk for chunk in self.pipeline.test_data])
        self.assertEqual(len(data['last_name'].drop_duplicates()), 10)

    def test_split_subsample(self):
        self.pipeline = tests.mocks.pipelines.Users()
        self.pipeline.subsample = 50

        self.pipeline._split_data()
        self.assertEqual(self.pipeline.table_length(self.pipeline.table_training), 40)
        self.assertEqual(self.pipeline.table_length(self.pipeline.table_validation), 5)
        self.assertEqual(self.pipeline.table_length(self.pipeline.table_test), 5)

    def test_encoded_data(self):
        self.pipeline = tests.mocks.pipelines.Users()
        self.pipeline.subsample = 50
        self.pipeline.connection.execute('drop table if exists {name}'.format(name=self.pipeline.table_training + '_random'))
        self.assertEqual(len(pandas.concat([chunk.x for chunk in self.pipeline.encoded_training_data])), 40)
        self.assertEqual(len(pandas.concat([chunk.x for chunk in self.pipeline.encoded_validation_data])), 5)
        self.assertEqual(len(pandas.concat([chunk.x for chunk in self.pipeline.encoded_test_data])), 5)

    def test_encoded_data_twin(self):
        self.pipeline = tests.mocks.pipelines.TwinData()
        self.pipeline.subsample = 50
        self.assertEqual(len(self.pipeline.encoded_training_data[0]), 3)

    def test_preserves_types(self):
        self.pipeline = tests.mocks.pipelines.Users()
        training_data = pandas.concat([chunk for chunk in self.pipeline.training_data])
        self.assertEqual(training_data['id'].dtype.kind, 'i')          # integer
        self.assertEqual(training_data['first_name'].dtype.kind, 'O')  # object
        self.assertEqual(training_data['last_name'].dtype.kind, 'O')   # object
        self.assertEqual(training_data['subscriber'].dtype.kind, 'b')  # bool
        self.assertEqual(training_data['signup_at'].dtype.kind, 'M')   # datetime64[ns]

    def test_generator(self):
        pipeline = tests.mocks.pipelines.Users()
        pipeline.stratify = 'last_name'
        chunks = 0
        length = 0
        for chunk in pipeline.generator(pipeline.table_training, orient='row', encoded=True, stratify=False, chunksize=200):
            chunks += 1
            length += len(chunk.x)
        self.assertEqual(chunks, 4)
        self.assertEqual(length, 800)

        chunks = 0
        length = 0
        for chunk in pipeline.generator(pipeline.table_training, orient='row', encoded=True, stratify=True, chunksize=10):
            chunks += 1
            length += len(chunk.x)
        self.assertEqual(chunks, 80)
        self.assertEqual(length, 800)

        chunks = 0
        length = 0
        for chunk in pipeline.generator(pipeline.table_training, orient='column', encoded=True):
            chunks += 1
            length += len(chunk.x)
        self.assertEqual(chunks, 5)
        self.assertEqual(length, 5 * 800)
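The chunk counts asserted in test_generator are pure arithmetic over the 800-row training split established in test_split: row orientation yields ceil(rows / chunksize) chunks, and column orientation apparently yields one full-length chunk per column of the Users table (five columns, per test_preserves_types). Derived independently:

# The chunk counts asserted above, recomputed from the split sizes.
import math

rows = 800                                 # training split from test_split
print(math.ceil(rows / 200))               # 4 chunks,  orient='row', chunksize=200
print(math.ceil(rows / 10))                # 80 chunks, orient='row', chunksize=10
columns = ['id', 'first_name', 'last_name', 'subscriber', 'signup_at']
print(len(columns), len(columns) * rows)   # 5 chunks of 800 each, orient='column'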
--------------------------------------------------------------------------------
/tests/unit/test_stores.py:
--------------------------------------------------------------------------------
import datetime
import unittest
import os
import time

import lore
from lore.stores import query_cached
from lore.stores.disk import Disk


class TestDisk(unittest.TestCase):
    def setUp(self):
        self.one_shot_calls = 0

    @query_cached
    def one_shot_function(self, **kwargs):
        if self.one_shot_calls > 0:
            raise RuntimeError("Can't call me twice")
        self.one_shot_calls += 1
        return self.one_shot_calls

    def test_disk(self):
        cache = Disk(os.path.join(lore.env.DATA_DIR, 'cache'))

        for key in cache.keys():
            del cache[key]

        self.assertEqual(len(cache), 0)
        self.assertEqual(cache['a'], None)
        self.assertEqual(cache.keys(), [])

        cache['a'] = 1
        self.assertEqual(len(cache), 1)
        self.assertEqual(cache['a'], 1)
        self.assertEqual(cache.keys(), ['a'])

        time.sleep(1)  # disk cache LRU has 1 second precision

        cache['b'] = 2
        self.assertEqual(len(cache), 2)
        self.assertEqual(cache.lru(), 'a')
        self.assertEqual(cache['b'], 2)
        self.assertEqual(sorted(cache.keys()), ['a', 'b'])

        cache['b'] = 3
        self.assertEqual(len(cache), 2)
        self.assertEqual(cache.lru(), 'a')
        self.assertEqual(cache['b'], 3)
        self.assertEqual(sorted(cache.keys()), ['a', 'b'])

        del cache['b']
        self.assertEqual(len(cache), 1)
        self.assertEqual(cache.lru(), 'a')
        self.assertFalse('b' in cache)
        self.assertEqual(cache.keys(), ['a'])

        cache.limit = 0
        self.assertRaises(MemoryError, cache.__setitem__, 'a', 1)
        del cache['a']

        self.assertEqual(len(cache), 0)
        self.assertEqual(cache.lru(), None)
        self.assertFalse('a' in cache)
        self.assertEqual(cache.keys(), [])

    def test_query_cached(self):
        cache = lore.stores.query_cache
        length = len(cache)
        now = datetime.datetime.now()

        # the first call computes, and its result is stored in the cache
        calls = self.one_shot_function(when=now, int=length, str='hi', cache=True)
        self.assertEqual(length + 1, len(cache))
        self.assertEqual(1, calls)

        # the second call is served from the cache without re-executing
        calls = self.one_shot_function(when=now, int=length, str='hi', cache=True)
        self.assertEqual(length + 1, len(cache))
        self.assertEqual(1, calls)
        self.assertEqual(1, self.one_shot_calls)
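one_shot_function is the classic memoization probe: a second real invocation raises, so the assertions prove the repeat call was served from lore.stores.query_cache rather than recomputed. A standalone sketch of the kwargs-keyed pattern that @query_cached exercises; the cache dict and key scheme are illustrative, not lore.stores' implementation:

# Illustrative kwargs-keyed memoizer in the spirit of @query_cached.
import functools

query_cache = {}

def cached_by_kwargs(func):
    @functools.wraps(func)
    def wrapper(self, cache=False, **kwargs):
        if not cache:
            return func(self, **kwargs)
        # key on the function plus its keyword arguments, order-insensitively
        key = (func.__name__, tuple(sorted(kwargs.items())))
        if key not in query_cache:
            query_cache[key] = func(self, **kwargs)
        return query_cache[key]
    return wrapper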
--------------------------------------------------------------------------------
/tests/unit/test_transformers.py:
--------------------------------------------------------------------------------
import unittest
import datetime
import math

import numpy
import pandas

import lore.transformers


class TestAreaCode(unittest.TestCase):
    def setUp(self):
        self.transformer = lore.transformers.AreaCode('phone')

    def test_phone_formats(self):
        values = pandas.DataFrame({
            'phone': [
                '12345678901',
                '+12345678901',
                '1(234)567-8901',
                '1 (234) 567-8901',
                '1.234.567.8901',
                '1-234-567-8901',
                '2345678901',
                '234.567.8901',
                '(234)5678901',
                '(234) 567-8901',
            ]
        })
        result = self.transformer.transform(values)
        self.assertEqual(result.tolist(), numpy.repeat('234', len(values)).tolist())

    def test_bad_data(self):
        values = pandas.DataFrame({
            'phone': [
                '1234567',
                '(123)4567',
                '',
                None,
                12345678901,
            ]
        })
        result = self.transformer.transform(values)
        self.assertEqual(result.tolist(), ['', '', '', None, ''])


class TestEmailDomain(unittest.TestCase):
    def setUp(self):
        self.transformer = lore.transformers.EmailDomain('email')

    def test_transform(self):
        values = pandas.DataFrame({
            'email': [
                'montana@instacart.com',
                'sue-bob+anne@instacart.com'
            ]
        })
        result = self.transformer.transform(values)
        self.assertEqual(result.tolist(), numpy.repeat('instacart.com', len(values)).tolist())


class TestNameFamilial(unittest.TestCase):
    def setUp(self):
        self.transformer = lore.transformers.NameFamilial('name')

    def test_transform(self):
        values = pandas.DataFrame({
            'name': [
                'mom',
                'Dad',
                'sue bob'
            ]
        })
        result = self.transformer.transform(values)
        self.assertEqual(result.tolist(), [True, True, False])


class TestDateTime(unittest.TestCase):
    def test_transform_day_of_week(self):
        transformer = lore.transformers.DateTime('test', 'dayofweek')
        data = pandas.DataFrame({'test': [datetime.datetime(2016, 12, 31), datetime.date(2017, 1, 1)]})
        transformed = transformer.transform(data)
        self.assertEqual(transformed.iloc[0] + 1, transformed.iloc[1])


class TestAge(unittest.TestCase):
    def test_transform_age(self):
        transformer = lore.transformers.Age('test', unit='days')
        yesterday = datetime.datetime.now() - datetime.timedelta(days=1)

        data = pandas.DataFrame({'test': [datetime.datetime.now(), yesterday]})
        transformed = transformer.transform(data)
        self.assertEqual(transformed.astype(int).tolist(), [0, 1])


class TestNameAge(unittest.TestCase):
    def test_transform_name(self):
        transformer = lore.transformers.NameAge('test')

        data = pandas.DataFrame({'test': ['bob', 'Bob']})
        transformed = transformer.transform(data)
        self.assertTrue(transformed.iloc[0] > 0)
        self.assertEqual(transformed.iloc[0], transformed.iloc[1])


class TestNameSex(unittest.TestCase):
    def test_transform_name(self):
        transformer = lore.transformers.NameSex('test')

        data = pandas.DataFrame({'test': ['bob', 'Bob']})
        transformed = transformer.transform(data)
        self.assertTrue(transformed.iloc[0] > 0)
        self.assertEqual(transformed.iloc[0], transformed.iloc[1])


class TestNamePopulation(unittest.TestCase):
    def test_transform_name(self):
        transformer = lore.transformers.NamePopulation('test')

        data = pandas.DataFrame({'test': ['bob', 'Bob']})
        transformed = transformer.transform(data)
        self.assertTrue(transformed.iloc[0] > 0)
        self.assertEqual(transformed.iloc[0], transformed.iloc[1])


class TestStringLower(unittest.TestCase):
    def test_transform_name(self):
        transformer = lore.transformers.String('test', 'lower')

        data = pandas.DataFrame({'test': ['bob', 'Bob']})
        transformed = transformer.transform(data)
        self.assertEqual(transformed.iloc[0], 'bob')
        self.assertEqual(transformed.iloc[1], 'bob')


class TestGeoIP(unittest.TestCase):
    def test_transform_latitude(self):
        transformer = lore.transformers.GeoIP('test', 'latitude')

        data = pandas.DataFrame({'test': ['124.0.0.1', '124.0.0.2']})
        transformed = transformer.transform(data)
        self.assertAlmostEqual(transformed.iloc[0], 37.5112)
        self.assertAlmostEqual(transformed.iloc[1], 37.5112)

    def test_transform_longitude(self):
        transformer = lore.transformers.GeoIP('test', 'longitude')

        data = pandas.DataFrame({'test': ['124.0.0.1', '124.0.0.2']})
        transformed = transformer.transform(data)
        self.assertAlmostEqual(transformed.iloc[0], 126.9741)
        self.assertAlmostEqual(transformed.iloc[1], 126.9741)

    def test_transform_accuracy(self):
        transformer = lore.transformers.GeoIP('test', 'accuracy')

        data = pandas.DataFrame({'test': ['124.0.0.1', '124.0.0.2']})
        transformed = transformer.transform(data)
        self.assertEqual(transformed.iloc[0], 200)
        self.assertEqual(transformed.iloc[1], 200)

    def test_missing_ip(self):
        transformer = lore.transformers.GeoIP('test', 'accuracy')
        data = pandas.DataFrame({'test': ['127.0.0.2']})
        transformed = transformer.transform(data)
        self.assertTrue(math.isnan(transformed.iloc[0]))


class TestDistance(unittest.TestCase):
    def test_distance(self):
        data = pandas.DataFrame({
            'a_lat': [0., 52.2296756],
            'b_lat': [0., 52.406374],
            'a_lon': [0., 21.0122287],
            'b_lon': [0., 16.9251681]
        })

        transformer = lore.transformers.Distance(
            lat_a='a_lat',
            lat_b='b_lat',
            lon_a='a_lon',
            lon_b='b_lon',
        )

        transformed = transformer.transform(data)
        self.assertEqual(transformed.iloc[0], 0)
        self.assertAlmostEqual(transformed.iloc[1], 278.54558935106695)

    def test_ip(self):
        data = pandas.DataFrame({'a': ['124.0.0.1', '124.0.0.2'], 'b': ['124.0.0.1', '127.0.0.2']})

        transformer = lore.transformers.Distance(
            lat_a=lore.transformers.GeoIP('a', 'latitude'),
            lat_b=lore.transformers.GeoIP('b', 'latitude'),
            lon_a=lore.transformers.GeoIP('a', 'longitude'),
            lon_b=lore.transformers.GeoIP('b', 'longitude'),
        )

        transformed = transformer.transform(data)
        self.assertEqual(transformed.iloc[0], 0)
        self.assertTrue(math.isnan(transformed.iloc[1]))
--------------------------------------------------------------------------------