├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── doc ├── Makefile ├── code.rst ├── conf.py ├── design │ ├── client_design_doc.md │ └── magic_design_doc.md └── index.rst ├── index.html ├── proto └── sqlflow │ └── proto │ └── sqlflow.proto ├── setup.cfg ├── setup.py ├── sqlflow ├── __init__.py ├── __main__.py ├── _version.py ├── client.py ├── compound_message.py ├── env_expand.py ├── magic.py └── rows.py └── tests ├── __init__.py ├── mock_servicer.py ├── test_client.py ├── test_env_expand.py └── test_version.py /.gitignore: -------------------------------------------------------------------------------- 1 | sqlflow/proto 2 | 3 | venv/ 4 | build/ 5 | dist/ 6 | 7 | .eggs/ 8 | *.egg-info/ 9 | .pytest_cache 10 | __pycache__/ 11 | 12 | .idea/ 13 | 14 | *.swp 15 | *.vim 16 | *.pyc 17 | *.log 18 | 19 | .DS_Store 20 | .vscode 21 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | cache: pip 3 | branches: 4 | only: 5 | - master 6 | - develop 7 | # https://docs.travis-ci.com/user/customizing-the-build/#safelisting-or-blocklisting-branches 8 | # safe list can prevent tag building, add rexp to detect tag 9 | - "/^v\\d+\\.\\d+(\\.\\d+)?(-\\S*)?$/" 10 | 11 | language: python 12 | python: 13 | - 3.6 14 | - 3.7 15 | 16 | install: 17 | - python setup.py -q install 18 | 19 | script: 20 | - make protoc 21 | - python setup.py -q test 22 | 23 | deploy: 24 | provider: pypi 25 | user: tonyyang 26 | password: 27 | # secure token: `echo -n "password" | travis encrypt --add deploy.password -r sql-machine-learning/pysqlflow` 28 | secure: u8hoMuMukUcD6D+y+61w93yczgcixTgj8wQadv0VrI7aEsIOkMi508FXhsj4gmgN95O7fUApi6oXWhmOllVGFy2wJ+zZH1M8WDZi6Gq23meWTs7K+f9m6dADN0+/Y87Nc/UIOQLO8bU38hqophDRbC5Utn5bY5/lkMZzg8ypG9luZ+div+6DmZ3+kxh8o7jByiql0JjBeVLfycVW+LUpu6Fqx0XIp+tyWx99osT9LxDoZyPOmN4yXiyhEWNaK9BHi2X+qxxymjZvEJ0nSKeyQmfxBoKFW0PXh1t6r2yP5h1FzeuvZ3jB4hyvYAplKJbWJ6RLNISpBitB4PeeIVTFBPEzjc5gtDPvc5oy/OoLtmtbxGpAYh/J+DOA5T9S3vfM38xLByGMgA/Tpj0521RVd/UeyLgOrw7yPR1ZdyfDILoJ+TdupSNBpF5uMemuQGoGprGuMMHhZC+9Yv3+pUitnugDBHCx/A5uWSh699fv7UQBH0AgG5tFroQLZO6M1UVWyhkY1+LJZTuW8N4Xgq3d5dh6KdTaAJ7HZrAjow1Vc9U0S6RlWyhRgnuDwcVs3O2sPjufRSo1NK+P8Y78Gy10L1vtIr8MJRtTkTqcztrBZWAwjcvCHlgNfOPpPW1wR6klfBYN2/0MFFUYlo2ZTbsQSIgSLcVvR9d+TZkWaCsImPg= 29 | distributions: sdist bdist_wheel 30 | skip_existing: true 31 | skip_cleanup: true 32 | on: 33 | tags: true 34 | repo: sql-machine-learning/pysqlflow 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SQLFLOW_VERSION := develop 2 | VERSION_FILE := sqlflow/_version.py 3 | SHELL := /bin/bash 4 | 5 | setup: ## Setup virtual environment for local development 6 | python3 -m venv venv 7 | source venv/bin/activate \ 8 | && $(MAKE) install-requirements protoc 9 | 10 | install-requirements: 11 | pip install -e . 12 | 13 | test: ## Run tests 14 | python3 setup.py test 15 | 16 | clean: ## Clean up temporary folders 17 | rm -rf build dist .eggs *.egg-info .pytest_cache sqlflow/proto 18 | 19 | # NOTE(typhoonzero): grpcio-tools >= 1.29.0 breaks the tests like: 20 | # AttributeError: module 'google.protobuf.descriptor' has no attribute '_internal_create_key' 21 | protoc: ## Generate python client from proto file 22 | python3 -m venv build/grpc 23 | source build/grpc/bin/activate \ 24 | && pip install grpcio-tools==1.29.0 \ 25 | && mkdir -p build/grpc/sqlflow/proto \ 26 | && python -m grpc_tools.protoc -Iproto --python_out=. \ 27 | --grpc_python_out=. proto/sqlflow/proto/sqlflow.proto 28 | 29 | release: ## Release new version 30 | $(if $(shell git status -s), $(error "Please commit your changes or stash them before you release.")) 31 | 32 | # Make sure local develop branch is up-to-date 33 | git fetch origin 34 | git checkout develop 35 | git merge origin/develop 36 | 37 | # Remove dev from version number 38 | $(eval VERSION := $(subst .dev,,$(shell python -c "exec(open('$(VERSION_FILE)').read());print(__version__)"))) 39 | $(info release $(VERSION)...) 40 | sed -i '' "s/, 'dev'//" $(VERSION_FILE) 41 | git commit -a -m "release $(VERSION)" 42 | 43 | # Tag it 44 | git tag v$(VERSION) 45 | 46 | # Bump version for development 47 | $(eval NEXT_VERSION := $(shell echo $(VERSION) | awk -F. 
'{print $$1"."($$2+1)".0"}'))
48 | $(eval VERSION_CODE := $(shell echo $(NEXT_VERSION) | sed 's/\./, /g'))
49 | sed -i '' -E "s/[0-9]+, [0-9]+, [0-9]+/$(VERSION_CODE), 'dev'/" $(VERSION_FILE)
50 | git commit -a -m "start $(NEXT_VERSION)"
51 | git push origin develop
52 | 
53 | # Push the tag, release the package to pypi
54 | git push --tags
55 | 
56 | doc:
57 | $(MAKE) setup \
58 | && source venv/bin/activate \
59 | && pip install sphinx \
60 | && cd doc \
61 | && make clean && make html
62 | 
63 | help:
64 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
65 | 
66 | .PHONY: help doc
67 | .DEFAULT_GOAL := help
68 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # sqlflow [![Build Status](https://travis-ci.org/sql-machine-learning/pysqlflow.svg?branch=develop)](https://travis-ci.org/sql-machine-learning/pysqlflow) [![PyPI Package](https://img.shields.io/pypi/v/sqlflow.svg)](https://pypi.python.org/pypi/sqlflow)
2 | 
3 | [SQLFlow](https://github.com/sql-machine-learning/sqlflow) client library for Python.
4 | 
5 | ## Installation
6 | 
7 | This package is available on PyPI as `sqlflow`, so you can install it by running the following command:
8 | 
9 | pip install sqlflow
10 | 
11 | ## Documentation
12 | 
13 | You can read the Sphinx-generated docs at:
14 | [http://sql-machine-learning.github.io/pysqlflow/](http://sql-machine-learning.github.io/pysqlflow/)
15 | 
16 | ## Development
17 | 
18 | ### Prerequisite
19 | #### Python 3
20 | `brew install python`
21 | 
22 | ### Setup Environment
23 | `make setup`
24 | 
25 | ### Test
26 | `make test`
27 | 
28 | ### Release
29 | `make release`
30 | 
31 | ### Generate Documentation
32 | `make doc`
33 | 
34 | ### Generate gRPC client
35 | The gRPC client code is generated when you set up the environment.
36 | If you would like to regenerate it, run `make protoc`.
-------------------------------------------------------------------------------- /doc/Makefile: --------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 | 
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 | 
14 | .PHONY: help Makefile
15 | 
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-------------------------------------------------------------------------------- /doc/code.rst: --------------------------------------------------------------------------------
1 | Python Client for SQLFlow Server
2 | ================================
3 | 
4 | .. automodule:: sqlflow.client
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 | 
9 | .. 
automodule:: sqlflow.magic 10 | :members: 11 | :undoc-members: 12 | :show-inheritance: 13 | 14 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('../')) 18 | 19 | # enable auto generate doc on __init__ 20 | autoclass_content = 'both' 21 | 22 | 23 | 24 | # -- Project information ----------------------------------------------------- 25 | 26 | # TODO(tony): read project, author, version, release from setup.py 27 | project = 'pysqlflow' 28 | copyright = '2019, Yang Yang' 29 | author = 'Yang Yang' 30 | 31 | # The short X.Y version 32 | version = '' 33 | # The full version, including alpha/beta/rc tags 34 | release = '' 35 | 36 | 37 | # -- General configuration --------------------------------------------------- 38 | 39 | # If your documentation needs a minimal Sphinx version, state it here. 40 | # 41 | # needs_sphinx = '1.0' 42 | 43 | # Add any Sphinx extension module names here, as strings. They can be 44 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 45 | # ones. 46 | extensions = [ 47 | 'sphinx.ext.autodoc', 48 | 'sphinx.ext.githubpages', 49 | ] 50 | 51 | # Add any paths that contain templates here, relative to this directory. 52 | templates_path = ['_templates'] 53 | 54 | # The suffix(es) of source filenames. 55 | # You can specify multiple suffix as a list of string: 56 | # 57 | # source_suffix = ['.rst', '.md'] 58 | source_suffix = '.rst' 59 | 60 | # The master toctree document. 61 | master_doc = 'index' 62 | 63 | # The language for content autogenerated by Sphinx. Refer to documentation 64 | # for a list of supported languages. 65 | # 66 | # This is also used if you do content translation via gettext catalogs. 67 | # Usually you set "language" from the command line for these cases. 68 | language = None 69 | 70 | # List of patterns, relative to source directory, that match files and 71 | # directories to ignore when looking for source files. 72 | # This pattern also affects html_static_path and html_extra_path. 73 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 74 | 75 | # The name of the Pygments (syntax highlighting) style to use. 76 | pygments_style = None 77 | 78 | 79 | # -- Options for HTML output ------------------------------------------------- 80 | 81 | # The theme to use for HTML and HTML Help pages. See the documentation for 82 | # a list of builtin themes. 83 | # 84 | html_theme = 'alabaster' 85 | 86 | # Theme options are theme-specific and customize the look and feel of a theme 87 | # further. For a list of options available for each theme, see the 88 | # documentation. 89 | # 90 | # html_theme_options = {} 91 | 92 | # Add any paths that contain custom static files (such as style sheets) here, 93 | # relative to this directory. 
They are copied after the builtin static files,
94 | # so a file named "default.css" will overwrite the builtin "default.css".
95 | html_static_path = ['_static']
96 | 
97 | # Custom sidebar templates, must be a dictionary that maps document names
98 | # to template names.
99 | #
100 | # The default sidebars (for documents that don't match any pattern) are
101 | # defined by theme itself. Builtin themes are using these templates by
102 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
103 | # 'searchbox.html']``.
104 | #
105 | # html_sidebars = {}
106 | 
107 | 
108 | # -- Options for HTMLHelp output ---------------------------------------------
109 | 
110 | # Output file base name for HTML help builder.
111 | htmlhelp_basename = 'pysqlflowdoc'
112 | 
113 | 
114 | # -- Options for LaTeX output ------------------------------------------------
115 | 
116 | latex_elements = {
117 | # The paper size ('letterpaper' or 'a4paper').
118 | #
119 | # 'papersize': 'letterpaper',
120 | 
121 | # The font size ('10pt', '11pt' or '12pt').
122 | #
123 | # 'pointsize': '10pt',
124 | 
125 | # Additional stuff for the LaTeX preamble.
126 | #
127 | # 'preamble': '',
128 | 
129 | # Latex figure (float) alignment
130 | #
131 | # 'figure_align': 'htbp',
132 | }
133 | 
134 | # Grouping the document tree into LaTeX files. List of tuples
135 | # (source start file, target name, title,
136 | # author, documentclass [howto, manual, or own class]).
137 | latex_documents = [
138 | (master_doc, 'pysqlflow.tex', 'pysqlflow Documentation',
139 | 'Yang Yang', 'manual'),
140 | ]
141 | 
142 | 
143 | # -- Options for manual page output ------------------------------------------
144 | 
145 | # One entry per manual page. List of tuples
146 | # (source start file, name, description, authors, manual section).
147 | man_pages = [
148 | (master_doc, 'pysqlflow', 'pysqlflow Documentation',
149 | [author], 1)
150 | ]
151 | 
152 | 
153 | # -- Options for Texinfo output ----------------------------------------------
154 | 
155 | # Grouping the document tree into Texinfo files. List of tuples
156 | # (source start file, target name, title, author,
157 | # dir menu entry, description, category)
158 | texinfo_documents = [
159 | (master_doc, 'pysqlflow', 'pysqlflow Documentation',
160 | author, 'pysqlflow', 'One line description of project.',
161 | 'Miscellaneous'),
162 | ]
163 | 
164 | 
165 | # -- Options for Epub output -------------------------------------------------
166 | 
167 | # Bibliographic Dublin Core info.
168 | epub_title = project
169 | 
170 | # The unique identifier of the text. This can be an ISBN number
171 | # or the project homepage.
172 | #
173 | # epub_identifier = ''
174 | 
175 | # A unique identification for the text.
176 | #
177 | # epub_uid = ''
178 | 
179 | # A list of files that should not be packed into the epub file.
180 | epub_exclude_files = ['search.html']
181 | 
182 | 
183 | # -- Extension configuration -------------------------------------------------
-------------------------------------------------------------------------------- /doc/design/client_design_doc.md: --------------------------------------------------------------------------------
1 | # SQLFlow Client Design Doc
2 | 
3 | ## Overview
4 | 
5 | SQLFlow Client connects to [sqlflowserver](https://github.com/sql-machine-learning/sqlflowserver).
6 | It has only one method, `Run`, which takes a SQL statement and returns a `RowSet` object.
7 | 
8 | ## Example
9 | 
10 | ```python
11 | import sqlflow
12 | 
13 | client = sqlflow.Client(server_url='localhost:50051')
14 | 
15 | # Query SQL
16 | rowset = client.run('SELECT ... FROM ...')
17 | for row in rowset:
18 | print(row) # [1, 1]
19 | 
20 | # Execution SQL, prints
21 | # Query OK, ... row affected (... sec)
22 | client.run('DELETE FROM ... WHERE ...')
23 | 
24 | # ML SQL, prints
25 | # epoch = 0, loss = ...
26 | # epoch = 1, loss = ...
27 | # ...
28 | client.run('SELECT ... TRAIN ...')
29 | ```
30 | 
31 | ## Service Protocol
32 | 
33 | `sqlflow.Client` uses gRPC to contact the `sqlflowserver`. The service protocol
34 | is defined [here](sqlflow/proto/sqlflow.proto).
35 | 
36 | ## Implementation
37 | 
38 | `sqlflow.Client.__init__` establishes a gRPC stub/channel based on `server_url`.
39 | 
40 | `sqlflow.Client.run` takes a SQL statement and returns a `RowSet` object.
41 | ```python
42 | class Client:
43 | def __init__(self, host):
44 | channel = grpc.insecure_channel(host)
45 | self._stub = sqlflow_pb2_grpc.SQLFlowStub(channel)
46 | 
47 | def _decode_protobuf(self, proto):
48 | '''decode rowset'''
49 | 
50 | def run(self, operation):
51 | def rowset_gen():
52 | for res in self._stub.Run(sqlflow_pb2.Request(sql=operation)):
53 | if res.is_message():
54 | log(res)
55 | else:
56 | yield self._decode_protobuf(res)
57 | 
58 | return RowSet(rowset_gen=rowset_gen)
59 | 
60 | class RowSet:
61 | def __init__(self, rowset_gen):
62 | res = [r for r in rowset_gen]
63 | if res:
64 | self._head = res[0]
65 | self._rows = res[1:]
66 | else:
67 | self._head, self._rows = None, None
68 | 
69 | def __repr__(self):
70 | '''used for IPython: pretty prints self'''
71 | 
72 | def rows(self):
73 | return self._rows
74 | 
75 | def to_dataframe(self):
76 | '''convert to dataframes'''
77 | ```
78 | 
79 | ## Pagination
80 | 
81 | Currently the SQLFlow server doesn't support pagination, and neither does the client.
82 | If we want to support it in the future, it can be implemented by passing
83 | page tokens. See, for example, the following code snippet from
84 | [google-api-go-client](https://github.com/googleapis/google-api-go-client/blob/master/iterator/iterator.go#L68):
85 | 
86 | ```go
87 | // Function that fetches a page from the underlying service. It should pass
88 | // the pageSize and pageToken arguments to the service, fill the buffer
89 | // with the results from the call, and return the next-page token returned
90 | // by the service. The function must not remove any existing items from the
91 | // buffer. If the underlying RPC takes an int32 page size, pageSize should
92 | // be silently truncated.
93 | fetch func(pageSize int, pageToken string) (nextPageToken string, err error)
94 | ```
95 | 
96 | ## Credential
97 | 
98 | The authorization between the client and the server should be independent
99 | of the authorization between the server and the database. A client should never
100 | store any sensitive data such as DB usernames and passwords.
101 | 
-------------------------------------------------------------------------------- /doc/design/magic_design_doc.md: --------------------------------------------------------------------------------
1 | # Magic Command Design Doc
2 | 
3 | ## Overview
4 | 
5 | The magic command wraps a pysqlflow client. It pretty-prints logs and tables.
6 | 
7 | ## Example
8 | 
9 | The magic command can be invoked through `%%sqlflow`.
10 | 11 | ``` 12 | In [1]: %load_ext sqlflow 13 | 14 | In [2]: %%sqlflow select * from iris.iris limit 1; 15 | Out[2]: executing query 100% [========================================] 16 | +--------------+-------------+--------------+-------------+-------+ 17 | | sepal_length | sepal_width | petal_length | petal_width | class | 18 | +--------------+-------------+--------------+-------------+-------+ 19 | | 6.4 | 2.8 | 5.6 | 2.2 | 2 | 20 | +--------------+-------------+--------------+-------------+-------+ 21 | 22 | In [3]: %%sqlflow select * 23 | ...: from iris.iris limit 1; 24 | ...: 25 | Out[3]: executing query 100% [========================================] 26 | +--------------+-------------+--------------+-------------+-------+ 27 | | sepal_length | sepal_width | petal_length | petal_width | class | 28 | +--------------+-------------+--------------+-------------+-------+ 29 | | 6.4 | 2.8 | 5.6 | 2.2 | 2 | 30 | +--------------+-------------+--------------+-------------+-------+ 31 | 32 | In [4]: %%sqlflow SELECT * 33 | ...: FROM iris.iris limit 1 34 | ...: TRAIN DNNClassifier 35 | ...: WITH 36 | ...: n_classes = 3, 37 | ...: hidden_units = [10, 20] 38 | ...: COLUMN sepal_length, sepal_width, petal_length, petal_width 39 | ...: LABEL class 40 | ...: INTO my_dnn_model; 41 | Out[4]: 42 | Epoch 0: Training Accuracy ... Validation Accuracy ... 43 | Epoch 1: Training Accuracy ... Validation Accuracy ... 44 | Epoch 2: Training Accuracy ... Validation Accuracy ... 45 | ... 46 | Train Finished. Model saved at my_dnn_model 47 | ``` 48 | 49 | ## Implementation 50 | 51 | ### Pretty print 52 | 53 | #### Table 54 | 55 | Some off-the-shelf library: https://stackoverflow.com/a/26937531/6794675 56 | 57 | ``` 58 | >>> from prettytable import PrettyTable 59 | >>> t = PrettyTable(['Name', 'Age']) 60 | >>> t.add_row(['Alice', 24]) 61 | >>> t.add_row(['Bob', 19]) 62 | >>> print t 63 | +-------+-----+ 64 | | Name | Age | 65 | +-------+-----+ 66 | | Alice | 24 | 67 | | Bob | 19 | 68 | +-------+-----+ 69 | ``` 70 | 71 | #### Log 72 | 73 | Progress bar: https://stackoverflow.com/questions/3002085/python-to-print-out-status-bar-and-percentage 74 | 75 | ```text 76 | [================ ] 60% 77 | ``` 78 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. pysqlflow documentation master file, created by 2 | sphinx-quickstart on Mon Feb 25 15:53:09 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to pysqlflow's documentation! 7 | ===================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | code 14 | 15 | Indices and tables 16 | ================== 17 | 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | 10 |
Redirecting to https://github.com/sql-machine-learning/pysqlflow
11 | 
12 | 
-------------------------------------------------------------------------------- /proto/sqlflow/proto/sqlflow.proto: --------------------------------------------------------------------------------
1 | syntax = "proto3";
2 | 
3 | import "google/protobuf/any.proto";
4 | 
5 | package proto;
6 | 
7 | service SQLFlow {
8 | // Run executes a SQL statement.
9 | //
10 | // SQL statements like `SELECT ...`, `DESCRIBE ...` return a rowset.
11 | // The rowset might be big. In such cases, Run returns a stream
12 | // of Response messages.
13 | //
14 | // SQLFlow implements the Run interface with two modes:
15 | //
16 | // 1. Local mode
17 | // The SQLFlow server executes the SQL statements on the local host.
18 | //
19 | // SQL statements like `USE database`, `DELETE` return only a success
20 | // message.
21 | //
22 | // A SQL statement like `SELECT ... TO TRAIN/PREDICT ...` returns a stream of
23 | // messages that indicate the training/predicting progress.
24 | //
25 | // 2. Argo Workflow mode
26 | // The SQLFlow server submits an Argo workflow to a Kubernetes cluster,
27 | // and returns a stream of messages that indicate the workflow ID and the
28 | // submission progress.
29 | //
30 | // The SQLFlow gRPC client should fetch the logs of the workflow by
31 | // calling the Fetch interface in a polling manner.
32 | rpc Run (Request) returns (stream Response);
33 | 
34 | // Fetch fetches the SQLFlow job phase and logs in a polling manner. A corresponding
35 | // client can be implemented as
36 | //
37 | // wfJob := Submit(argoYAML)
38 | // req := &FetchRequest {
39 | // Job : { Id: wfJob },
40 | // }
41 | // for {
42 | // res := Fetch(req)
43 | // fmt.Println(res.Logs)
44 | // if isComplete(res) {
45 | // break
46 | // }
47 | // req = res.UpdatedFetchSince
48 | // time.Sleep(time.Second)
49 | // }
50 | //
51 | rpc Fetch (FetchRequest) returns (FetchResponse);
52 | }
53 | 
54 | message Job {
55 | string id = 1;
56 | string namespace = 2;
57 | }
58 | 
59 | message FetchRequest {
60 | Job job = 1;
61 | // the following fields keep the fetching state
62 | string step_id = 2;
63 | string step_phase = 3;
64 | }
65 | 
66 | message FetchResponse {
67 | message Responses {
68 | repeated Response response = 1;
69 | }
70 | FetchRequest updated_fetch_since = 1;
71 | bool eof = 2;
72 | Responses responses = 4;
73 | }
74 | 
75 | message Session {
76 | string token = 1;
77 | string db_conn_str = 2;
78 | bool exit_on_submit = 3;
79 | string user_id = 4;
80 | // for loading CSV to hive
81 | string hive_location = 5;
82 | string hdfs_namenode_addr = 6;
83 | string hdfs_user = 7;
84 | string hdfs_pass = 8;
85 | string submitter = 9;
86 | // for rbac
87 | string service_account = 10;
88 | string wf_namespace = 11;
89 | }
90 | 
91 | // SQL statements to run
92 | // e.g.
93 | // 1. `SELECT ...`
94 | // 2. `USE ...`, `DELETE ...`
95 | // 3. `SELECT ... TO TRAIN/PREDICT ...`
96 | message Request {
97 | string sql = 1; // The SQL statement to be executed.
98 | Session session = 2;
99 | }
100 | 
101 | message Response {
102 | oneof response {
103 | Head head = 1;
104 | Row row = 2;
105 | Message message = 3;
106 | EndOfExecution eoe = 4;
107 | Job job = 5;
108 | }
109 | }
110 | 
111 | // SQL statements like `SELECT ...`, `DESCRIBE ...` return a Head
112 | // and a sequence of Rows.
113 | message Head {
114 | repeated string column_names = 1;
115 | }
116 | 
117 | message Row {
118 | // Null is a special marker used in Structured Query Language to indicate
119 | // that a data value does not exist in the database.
120 | // We encoded this marker as message Null, and it is one possible type of 121 | // google.protobuf.Any in the field data 122 | message Null {} 123 | repeated google.protobuf.Any data = 1; 124 | } 125 | 126 | // SQL statements like `USE database`, `DELETE` returns only a success 127 | // message. 128 | // 129 | // SQL statement like `SELECT ... TO TRAIN/PREDICT ...` returns a stream of 130 | // messages which indicates the training/predicting progress 131 | message Message { 132 | string message = 1; 133 | } 134 | 135 | // SQLFlow server may execute multiple SQL statements in one RPC call. 136 | // EndOfExecution message tells the client that execution of one SQL is 137 | // finished, the client should go to next loop to parse the result stream. 138 | message EndOfExecution { 139 | string sql = 1; 140 | int64 spent_time_seconds = 2; 141 | } 142 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | rootdir=tests -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import io 5 | import os 6 | 7 | from setuptools import find_packages, setup 8 | 9 | # Package meta-data. 10 | NAME = 'sqlflow' 11 | DESCRIPTION = 'SQLFlow client library for Python.' 12 | URL = 'https://github.com/sql-machine-learning/sqlflow' 13 | EMAIL = 'kuisong.tong@gmail.com' 14 | AUTHOR = 'Kuisong Tong' 15 | REQUIRES_PYTHON = '>=3.5.0' 16 | VERSION = None 17 | 18 | # What packages are required for this module to be executed? 19 | REQUIRED = [ 20 | 'protobuf==3.7.1', 21 | 'grpcio >=1.17, <2', 22 | 'ipython==7.9', 23 | 'pandas', 24 | 'tornado>=6.0.0', 25 | 'python-dotenv>0.10.0', 26 | 'nest_asyncio>=1.4.0' 27 | ] 28 | SETUP_REQUIRED = [ 29 | 'pytest-runner' 30 | ] 31 | TEST_REQUIRED = [ 32 | 'pytest', 33 | ] 34 | 35 | # What packages are optional? 36 | EXTRAS = { 37 | } 38 | 39 | # The rest you shouldn't have to touch too much :) 40 | # ------------------------------------------------ 41 | # Except, perhaps the License and Trove Classifiers! 42 | # If you do change the License, remember to change the Trove Classifier for that! 43 | 44 | here = os.path.abspath(os.path.dirname(__file__)) 45 | 46 | # Import the README and use it as the long-description. 47 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 48 | try: 49 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 50 | long_description = '\n' + f.read() 51 | except FileNotFoundError: 52 | long_description = DESCRIPTION 53 | 54 | # Load the package's _version.py module as a dictionary. 
55 | about = {} 56 | if not VERSION: 57 | with open(os.path.join(here, NAME, '_version.py')) as f: 58 | exec(f.read(), about) 59 | else: 60 | about['__version__'] = VERSION 61 | 62 | # Where the magic happens: 63 | setup( 64 | name=NAME, 65 | version=about['__version__'], 66 | description=DESCRIPTION, 67 | long_description=long_description, 68 | long_description_content_type='text/markdown', 69 | author=AUTHOR, 70 | author_email=EMAIL, 71 | python_requires=REQUIRES_PYTHON, 72 | url=URL, 73 | packages=find_packages(exclude=('tests',)), 74 | package_data={'sqlflow': ['proto/*.py']}, 75 | entry_points={ 76 | 'console_scripts': ['sqlflow = sqlflow.__main__:main'], 77 | }, 78 | install_requires=REQUIRED, 79 | setup_requires=SETUP_REQUIRED, 80 | tests_require=TEST_REQUIRED, 81 | extras_require=EXTRAS, 82 | license='Apache License 2.0', 83 | classifiers=[ 84 | # Trove classifiers 85 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 86 | 'License :: OSI Approved :: Apache Software License', 87 | 'Programming Language :: Python', 88 | 'Programming Language :: Python :: 3 :: Only', 89 | 'Programming Language :: Python :: 3.5', 90 | 'Programming Language :: Python :: 3.6', 91 | 'Programming Language :: Python :: 3.7', 92 | 'Programming Language :: Python :: Implementation :: CPython', 93 | 'Programming Language :: Python :: Implementation :: PyPy' 94 | ], 95 | ) 96 | -------------------------------------------------------------------------------- /sqlflow/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import * 2 | from .env_expand import * 3 | from ._version import __version__ 4 | -------------------------------------------------------------------------------- /sqlflow/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from sqlflow.client import Client 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('sql', nargs='+', type=str, help="sql", action="store") 7 | parser.add_argument("--url", type=str, help="server url", action="store", default=None) 8 | parser.add_argument("--ca_crt", type=str, help="Path to CA certificates of SQLFlow client.", action="store", default=None) 9 | 10 | def main(): 11 | args = parser.parse_args() 12 | 13 | client = Client(server_url=args.url, ca_crt=args.ca_crt) 14 | for sql in args.sql: 15 | print("executing: {}".format(sql)) 16 | print(client.execute(sql)) 17 | -------------------------------------------------------------------------------- /sqlflow/_version.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 16, 0, 'dev') 2 | 3 | __version__ = '.'.join(map(str, VERSION)) 4 | -------------------------------------------------------------------------------- /sqlflow/client.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import grpc 5 | import re 6 | import time 7 | 8 | import google.protobuf.wrappers_pb2 as wrapper 9 | from google.protobuf.timestamp_pb2 import Timestamp 10 | 11 | import sqlflow.proto.sqlflow_pb2 as pb 12 | import sqlflow.proto.sqlflow_pb2_grpc as pb_grpc 13 | 14 | from sqlflow.env_expand import EnvExpander, EnvExpanderError 15 | from sqlflow.rows import Rows 16 | from sqlflow.compound_message import CompoundMessage 17 | 18 | _LOGGER = logging.getLogger(__name__) 19 | handler = logging.StreamHandler(sys.stdout) 20 | _LOGGER.setLevel(logging.INFO) 21 | 
_LOGGER.addHandler(handler)
22 | # the default timeout is 10 hours, to tolerate waiting for training
23 | # jobs to finish.
24 | DEFAULT_TIMEOUT=3600 * 10
25 | 
26 | # HTML code prefix in step logs
27 | HTML_PREFIX="data:text/html,"
28 | 
29 | 
30 | class StreamReader(object):
31 | def __init__(self, stream_response):
32 | self._stream_response = stream_response
33 | self.last_response = None
34 | 
35 | def read_one(self):
36 | try:
37 | res = next(self._stream_response)
38 | except StopIteration:
39 | return (None, None)
40 | 
41 | return (res, res.WhichOneof('response'))
42 | 
43 | def read_until_type_changed(self):
44 | first_rtype = None
45 | while True:
46 | try:
47 | response = next(self._stream_response)
48 | except StopIteration:
49 | break
50 | 
51 | rtype = response.WhichOneof('response')
52 | if first_rtype is None:
53 | first_rtype = rtype
54 | 
55 | if first_rtype != rtype:
56 | self.last_response = response
57 | break
58 | yield response
59 | 
60 | class Client:
61 | def __init__(self, server_url=None, ca_crt=None):
62 | """A minimal client that issues queries to and fetches results/logs from sqlflowserver.
63 | 
64 | :param server_url: sqlflowserver url. If None, read the value from the
65 | environment variable SQLFLOW_SERVER.
66 | :type server_url: str.
67 | 
68 | :param ca_crt: Path to the CA certificates of the SQLFlow client. If None,
69 | try to find the file from the environment variable
70 | SQLFLOW_CA_CRT; otherwise use an insecure client.
71 | :type ca_crt: str.
72 | 
73 | :raises: ValueError
74 | 
75 | Example:
76 | >>> client = sqlflow.Client(server_url="localhost:50051")
77 | 
78 | """
79 | if server_url is None:
80 | if "SQLFLOW_SERVER" not in os.environ:
81 | raise ValueError("Can't find environment variable SQLFLOW_SERVER")
82 | server_url = os.environ["SQLFLOW_SERVER"]
83 | 
84 | self._stub = pb_grpc.SQLFlowStub(self.new_rpc_channel(server_url, ca_crt))
85 | self._expander = EnvExpander(os.environ)
86 | 
87 | def new_rpc_channel(self, server_url, ca_crt):
88 | if ca_crt is None and "SQLFLOW_CA_CRT" not in os.environ:
89 | # the client connects to the SQLFlow gRPC server in insecure mode.
90 | channel = grpc.insecure_channel(server_url)
91 | else:
92 | if ca_crt is None:
93 | ca_crt = os.environ["SQLFLOW_CA_CRT"]
94 | with open(ca_crt, "rb") as f:
95 | creds = grpc.ssl_channel_credentials(f.read())
96 | channel = grpc.secure_channel(server_url, creds)
97 | return channel
98 | 
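# A minimal sketch of how a request picks up its session from the
# environment (hypothetical values; `client` is a Client instance):
#
#   os.environ["SQLFLOW_DATASOURCE"] = "mysql://root:root@tcp(127.0.0.1:3306)/"
#   req = client.sql_request("SELECT 1")
#   assert req.session.db_conn_str == "mysql://root:root@tcp(127.0.0.1:3306)/"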
99 | def sql_request(self, sql):
100 | token = os.getenv("SQLFLOW_USER_TOKEN", "")
101 | db_conn_str = os.getenv("SQLFLOW_DATASOURCE", "")
102 | submitter = os.getenv("SQLFLOW_SUBMITTER", "")
103 | exit_on_submit_env = os.getenv("SQLFLOW_EXIT_ON_SUBMIT", "True")
104 | user_id = os.getenv("SQLFLOW_USER_ID", "")
105 | hive_location = os.getenv("SQLFLOW_HIVE_LOCATION", "")
106 | hdfs_namenode_addr = os.getenv("SQLFLOW_HDFS_NAMENODE_ADDR", "")
107 | # the environment variables JUPYTER_HADOOP_USER and JUPYTER_HADOOP_PASS store the user's HDFS credentials.
108 | hdfs_user = os.getenv("JUPYTER_HADOOP_USER", "")
109 | hdfs_pass = os.getenv("JUPYTER_HADOOP_PASS", "")
110 | service_account = os.getenv("SQLFLOW_SERVICE_ACCOUNT", "")
111 | wf_namespace = os.getenv("SQLFLOW_WF_NAMESPACE", "")
112 | 
113 | if exit_on_submit_env.isdigit():
114 | exit_on_submit = bool(int(exit_on_submit_env))
115 | else:
116 | exit_on_submit = exit_on_submit_env.lower() == "true"
117 | se = pb.Session(token=token,
118 | db_conn_str=db_conn_str,
119 | exit_on_submit=exit_on_submit,
120 | user_id=user_id,
121 | hive_location=hive_location,
122 | hdfs_namenode_addr=hdfs_namenode_addr,
123 | hdfs_user=hdfs_user,
124 | hdfs_pass=hdfs_pass,
125 | service_account=service_account,
126 | wf_namespace=wf_namespace,
127 | submitter=submitter)
128 | try:
129 | sql = self._expander.expand(sql)
130 | except Exception as e:
131 | _LOGGER.error("%s", e)
132 | raise e
133 | return pb.Request(sql=sql, session=se)
134 | 
135 | def execute(self, operation):
136 | """Run a SQL statement
137 | 
138 | :param operation: SQL statement to be executed.
139 | :type operation: str.
140 | 
141 | :returns: sqlflow.client.Rows
142 | 
143 | Example:
144 | 
145 | >>> client.execute("select * from iris limit 1")
146 | 
147 | """
148 | try:
149 | stream_response = self._stub.Run(self.sql_request(operation), timeout=DEFAULT_TIMEOUT)
150 | return self.display(stream_response)
151 | except grpc.RpcError as e:
152 | # NOTE: raise the exception to interrupt notebook execution; otherwise
153 | # the notebook will continue to execute the next block.
154 | raise e
155 | except EnvExpanderError as e:
156 | raise e
157 | 
158 | def read_fetch_response(self, job):
159 | req = pb.FetchRequest()
160 | pb_job = pb.Job(id=job.id, namespace=job.namespace)
161 | pb_job.id = job.id
162 | pb_job.namespace = job.namespace
163 | req.job.CopyFrom(pb_job)
164 | compound_message = CompoundMessage()
165 | column_names = None
166 | rows = []
167 | 
168 | # TODO(yancey1989): use the common codebase with the stream response
169 | while True:
170 | fetch_response = self._stub.Fetch(req)
171 | for response in fetch_response.responses.response:
172 | rtype = response.WhichOneof('response')
173 | if rtype == 'message':
174 | msg = response.message.message
175 | if msg.startswith(HTML_PREFIX):
176 | from IPython.core.display import display, HTML
177 | display(HTML(msg[len(HTML_PREFIX):]))  # drop the prefix; str.lstrip would strip matching characters, not the prefix
178 | else:
179 | _LOGGER.info(msg)
180 | elif rtype == 'eoe':
181 | def _rows_gen():
182 | for row in rows:
183 | yield [self._decode_any(d) for d in row.data]
184 | if rows:
185 | compound_message.add_rows(Rows(column_names, _rows_gen), None)
186 | rows = []
187 | break
188 | elif rtype == 'head':
189 | column_names = [column_name for column_name in response.head.column_names]
190 | elif rtype == 'row':
191 | rows.append(response.row)
192 | else:
193 | # ignore other response types
194 | pass
195 | if fetch_response.eof:
196 | break
197 | req = fetch_response.updated_fetch_since
198 | time.sleep(2)
199 | if compound_message.empty():
200 | # return None to avoid outputting a blank line
201 | return None
202 | return compound_message
203 | 
204 | def display_html(self, first_line, stream_reader):
205 | resp_list = [first_line]
206 | for res in stream_reader.read_until_type_changed():
207 | resp_list.append(res.message.message)
208 | from IPython.core.display import display, HTML
209 | display(HTML('\n'.join(resp_list)))
210 | 
211 | def display(self, stream_response):
212 | """Display a stream response such as logs or table rows"""
213 | reader = StreamReader(stream_response)
214 | response, rtype = reader.read_one()
215 | compound_message = CompoundMessage()
216 | while True:
217 | if response is None:
218 | break
219 | if rtype == 'message':
220 | if re.match(r'<[a-z][\s\S]*>.*', response.message.message):
221 | self.display_html(response.message.message, reader)
222 | else:
223 | _LOGGER.info(response.message.message)
224 | for response in reader.read_until_type_changed():
225 | _LOGGER.info(response.message.message)
226 | response = reader.last_response
227 | if response is not None:
228 | rtype = response.WhichOneof('response')
229 | continue
230 | elif rtype == 'job':
231 | job = response.job
232 | # the last response type is Job in workflow mode,
233 | # so break the loop here
234 | return self.read_fetch_response(job)
235 | elif rtype == 'head' or rtype == 'row':
236 | column_names = [column_name for column_name in response.head.column_names]
237 | 
238 | def rows_gen():
239 | for res in reader.read_until_type_changed():
240 | yield [self._decode_any(a) for a in res.row.data]
241 | compound_message.add_rows(Rows(column_names, rows_gen), None)
242 | else:
243 | # deal with other response types in the future if necessary.
244 | pass
245 | 
246 | # read the next response
247 | response, rtype = reader.read_one()
248 | if compound_message.empty():
249 | # return None to avoid outputting a blank line
250 | return None
251 | return compound_message
252 | 
253 | @classmethod
254 | def _decode_any(cls, any_message):
255 | """Decode a google.protobuf.any_pb2 message
256 | """
257 | try:
258 | message = next(getattr(wrapper, type_name)()
259 | for type_name, desc in wrapper.DESCRIPTOR.message_types_by_name.items()
260 | if any_message.Is(desc))
261 | any_message.Unpack(message)
262 | return message.value
263 | except StopIteration:
264 | if any_message.Is(pb.Row.Null.DESCRIPTOR):
265 | return None
266 | if any_message.Is(Timestamp.DESCRIPTOR):
267 | timestamp_message = Timestamp()
268 | any_message.Unpack(timestamp_message)
269 | return timestamp_message.ToDatetime()
270 | raise TypeError("Unsupported type {}".format(any_message))
271 | 
-------------------------------------------------------------------------------- /sqlflow/compound_message.py: --------------------------------------------------------------------------------
1 | from sqlflow.rows import Rows
2 | 
3 | class CompoundMessage:
4 | def __init__(self):
5 | """Message containing the results of several SQL statements.
6 | CompoundMessage cannot be displayed in a notebook directly since we need to
7 | output log messages for long-running training SQL statements.
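
A minimal usage sketch (hypothetical values, using the methods defined below):
    >>> cm = CompoundMessage()
    >>> cm.add_message("Train Finished", None)
    >>> cm.length()
    1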
8 | """ 9 | self._messages = [] 10 | self.TypeRows = 1 11 | self.TypeMessage = 2 12 | self.TypeHTML = 3 13 | 14 | def add_rows(self, rows, eoe): 15 | assert(isinstance(rows, Rows)) 16 | # call __str__() to trigger rows_gen 17 | rows.__str__() 18 | self._messages.append((rows, eoe, self.TypeRows)) 19 | 20 | def add_message(self, message, eoe): 21 | assert(isinstance(message, str)) 22 | self._messages.append((message, eoe, self.TypeMessage)) 23 | 24 | def add_html(self, message, eoe): 25 | assert(isinstance(message, str)) 26 | self._messages.append((message, eoe, self.TypeHTML)) 27 | 28 | def length(self): 29 | return len(self._messages) 30 | 31 | def __str__(self): 32 | return self.__repr__() 33 | 34 | def __repr__(self): 35 | all_string = "" 36 | for r in self._messages: 37 | if isinstance(r[0], Rows): 38 | all_string = '\n'.join([all_string, r[0].__repr__()]) 39 | else: 40 | all_string = '\n'.join([all_string, r[0].__repr__()]) 41 | return all_string 42 | 43 | def _repr_html_(self): 44 | all_html = "" 45 | for r in self._messages: 46 | if isinstance(r[0], Rows): 47 | all_html = ''.join([all_html, r[0]._repr_html_()]) 48 | else: 49 | all_html = ''.join([all_html, "
<p>%s</p>" % (r[0].__str__().replace("\n", "<br>
"))]) 50 | return all_html 51 | 52 | def get(self, idx): 53 | return self._messages[idx][0] 54 | 55 | def empty(self): 56 | return len(self._messages) == 0 57 | -------------------------------------------------------------------------------- /sqlflow/env_expand.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from datetime import datetime, timedelta 4 | 5 | 6 | class EnvExpanderError(Exception): 7 | def __init__(self, message): 8 | self.message = message 9 | 10 | def __str__(self): 11 | return repr(self.message) 12 | 13 | class EnvExpander(object): 14 | def __init__(self, environ=os.environ): 15 | self.environ = environ 16 | self.pattern_env = re.compile(r'\$\{(.*?)\}') 17 | # TODO(Yancey1989): support more date expression 18 | self.pattern_date_expr = re.compile(r"(yyyyMMdd)\Z|(yyyyMMdd)\s*(\+|\-)\s*(\d+)d") 19 | 20 | 21 | def _match_date_expr(self, expr): 22 | return re.match(self.pattern_date_expr, expr) 23 | 24 | def parse_bizdate(self, expr): 25 | bizdate = self.environ["BIZDATE"][:8] 26 | if not bizdate: 27 | raise EnvExpanderError("The date format ${yyyyMMdd +/- 1d} need set the OS environment variable ${BIZDATE}") 28 | dt = datetime.strptime(bizdate, "%Y%m%d") 29 | m = self._match_date_expr(expr) 30 | if m.groups()[0]: 31 | return dt.strftime("%Y%m%d") 32 | else: 33 | if m.groups()[2] == "+": 34 | dt += timedelta(days=int(m.groups()[3])) 35 | elif m.groups()[2] == "-": 36 | dt -= timedelta(days=int(m.groups()[3])) 37 | else: 38 | raise EnvExpanderError("date format failed: {}".format(expr)) 39 | return dt.strftime("%Y%m%d") 40 | 41 | def expand(self, sql): 42 | new_sql = sql 43 | for m in re.finditer(self.pattern_env, sql): 44 | if m.group(1) in self.environ: 45 | val = self.environ[m.group(1)] 46 | new_sql = new_sql.replace(m.group(0), val, 1) 47 | elif self._match_date_expr(m.group(1)): 48 | val = self.parse_bizdate(m.group(1)) 49 | new_sql = new_sql.replace(m.group(0), val, 1) 50 | else: 51 | raise EnvExpanderError("Can not find the environment: {} in the runtime envionrment.".format(m.group(1))) 52 | return new_sql -------------------------------------------------------------------------------- /sqlflow/magic.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import os 5 | import ssl 6 | import sys 7 | import threading 8 | 9 | import dotenv 10 | import nest_asyncio 11 | from IPython.core.magic import Magics, cell_magic, line_magic, magics_class 12 | # For changing input cell syntax highlighting logic for the entire session 13 | # http://stackoverflow.com/questions/28703626/ipython-change-input-cell-syntax-highlighting-logic-for-entire-session 14 | from IPython.display import display_javascript 15 | from tornado import httpclient 16 | 17 | from sqlflow.client import _LOGGER, Client 18 | 19 | nest_asyncio.apply() 20 | 21 | 22 | @magics_class 23 | class SqlFlowMagic(Magics): 24 | """ 25 | Provides the `%%sqlflow` magic 26 | """ 27 | 28 | def __init__(self, shell): 29 | super(SqlFlowMagic, self).__init__(shell) 30 | self.client = None 31 | 32 | @cell_magic('sqlflow') 33 | def execute(self, line, cell): 34 | """Runs SQL statement 35 | 36 | :param line: The line magic 37 | :type line: str. 38 | :param cell: The cell magic 39 | :type cell: str. 40 | 41 | Example: 42 | 43 | >>> %%sqlflow SELECT * 44 | ... FROM mytable 45 | 46 | >>> %%sqlflow SELECT * 47 | ... FROM iris.iris limit 1 48 | ... TRAIN DNNClassifier 49 | ... 
-------------------------------------------------------------------------------- /sqlflow/magic.py: --------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import logging
4 | import os
5 | import ssl
6 | import sys
7 | import threading
8 | 
9 | import dotenv
10 | import nest_asyncio
11 | from IPython.core.magic import Magics, cell_magic, line_magic, magics_class
12 | # For changing input cell syntax highlighting logic for the entire session
13 | # http://stackoverflow.com/questions/28703626/ipython-change-input-cell-syntax-highlighting-logic-for-entire-session
14 | from IPython.display import display_javascript
15 | from tornado import httpclient
16 | 
17 | from sqlflow.client import _LOGGER, Client
18 | 
19 | nest_asyncio.apply()
20 | 
21 | 
22 | @magics_class
23 | class SqlFlowMagic(Magics):
24 | """
25 | Provides the `%%sqlflow` magic
26 | """
27 | 
28 | def __init__(self, shell):
29 | super(SqlFlowMagic, self).__init__(shell)
30 | self.client = None
31 | 
32 | @cell_magic('sqlflow')
33 | def execute(self, line, cell):
34 | """Runs a SQL statement
35 | 
36 | :param line: The line magic
37 | :type line: str.
38 | :param cell: The cell magic
39 | :type cell: str.
40 | 
41 | Example:
42 | 
43 | >>> %%sqlflow SELECT *
44 | ... FROM mytable
45 | 
46 | >>> %%sqlflow SELECT *
47 | ... FROM iris.iris limit 1
48 | ... TRAIN DNNClassifier
49 | ... WITH
50 | ... n_classes = 3,
51 | ... hidden_units = [10, 10]
52 | ... COLUMN sepal_length, sepal_width, petal_length, petal_width
53 | ... LABEL class
54 | ... INTO my_dnn_model;
55 | 
56 | """
57 | self.lazy_load()
58 | return self.client.execute('\n'.join([line, cell]))
59 | 
60 | def lazy_load(self):
61 | self.create_db_on_demand()
62 | if not self.client:
63 | self.client = Client()
64 | 
65 | def get_ssl_ctx(self):
66 | ca_path = os.getenv("SQLFLOW_PLAYGROUND_SERVER_CA")
67 | client_key = os.getenv("SQLFLOW_PLAYGROUND_CLIENT_KEY")
68 | client_cert = os.getenv("SQLFLOW_PLAYGROUND_CLIENT_CERT")
69 | if not (ca_path and client_cert and client_key):
70 | raise ValueError("Certificate files are not configured "
71 | "for this client")
72 | ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
73 | ssl_ctx.load_cert_chain(client_cert, client_key)
74 | ssl_ctx.load_verify_locations(ca_path)
75 | ssl_ctx.check_hostname = False
76 | 
77 | return ssl_ctx
78 | 
79 | def create_db_on_demand(self):
80 | """If we are connecting to a SQLFlow Playground Service,
81 | we need to ask the server for DB and SQLFlow resources.
82 | In this case, we do not specify SQLFLOW_DATASOURCE; instead
83 | we should specify the SQLFLOW_PLAYGROUND_SERVRE
84 | variable. The first time a %%sqlflow command is executed,
85 | we check whether the DB connection has been retrieved from the server,
86 | and create a new one if we haven't done so.
87 | 
88 | The client should be secured by certificate files. Use
89 | SQLFLOW_PLAYGROUND_CLIENT_KEY and SQLFLOW_PLAYGROUND_CLIENT_CERT
90 | and SQLFLOW_PLAYGROUND_SERVER_CA to specify them. These
91 | files can be obtained from the SQLFlow Playground Server maintainer.
92 | """
93 | 
94 | if os.getenv("SQLFLOW_DATASOURCE"):
95 | return
96 | # If no datasource is given, try to connect to
97 | # our playground server to create one
98 | user_id_env = os.getenv("SQLFLOW_PLAYGROUND_USER_ID_ENV")
99 | user_id = os.getenv(user_id_env)
100 | server = os.getenv("SQLFLOW_PLAYGROUND_SERVRE")
101 | if not server:
102 | raise ValueError("Neither a datasource nor a "
103 | "playground server is given.")
104 | if not user_id:
105 | raise ValueError(
106 | "Need to specify a SQLFLOW_PLAYGROUND_USER_ID_ENV")
107 | # give the user a hint; this may take a few seconds
108 | from IPython.core.display import display, HTML
109 | display(HTML("Loading resource..."))
110 | # create a db pod on the playground service
111 | body = {
112 | "user_id": user_id,
113 | }
114 | ssl_ctx = self.get_ssl_ctx()
115 | http = httpclient.HTTPClient()
116 | try:
117 | req = httpclient.HTTPRequest(
118 | "%s/api/create_db" % server, method="POST",
119 | ssl_options=ssl_ctx, body=json.dumps(body))
120 | resp = http.fetch(req)
121 | result = json.loads(resp.body)
122 | os.environ["SQLFLOW_DATASOURCE"] = result["data_source"]
123 | 
124 | # the server may kill the db resource
125 | # if no heartbeat arrives for a while
126 | self.setup_heart_beat(server, user_id)
127 | except Exception as e:
128 | raise RuntimeError("Can't get SQLFlow resource, because of", e)
129 | finally:
130 | http.close()
131 | 
132 | def setup_heart_beat(self, server, user_id):
133 | http = httpclient.HTTPClient()
134 | ssl_ctx = self.get_ssl_ctx()
135 | 
136 | async def report():
137 | while True:
138 | try:
139 | url = "%s/api/heart_beat?user_id=%s" % (server, user_id)
140 | http.fetch(url, ssl_options=ssl_ctx)
141 | await asyncio.sleep(10 * 60)
142 | except Exception:
143 | pass
144 | asyncio.ensure_future(report())
145 | 
146 | 
147 | def load_ipython_extension(ipython):
148 | if 
os.getenv("SQLFLOW_JUPYTER_ENV_PATH"): 149 | dotenv.load_dotenv(os.environ["SQLFLOW_JUPYTER_ENV_PATH"]) 150 | 151 | # Change input cell syntax highlighting to SQL 152 | js = "IPython.CodeCell.options_default.highlight_modes['magic_sql'] = {'reg':[/^%%sqlflow/]};" 153 | display_javascript(js, raw=True) 154 | 155 | magics = SqlFlowMagic(ipython) 156 | ipython.register_magics(magics) 157 | -------------------------------------------------------------------------------- /sqlflow/rows.py: -------------------------------------------------------------------------------- 1 | class Rows: 2 | def __init__(self, column_names, rows_gen): 3 | """Query result of sqlflow.client.Client.execute 4 | 5 | :param column_names: column names 6 | :type column_names: list[str]. 7 | :param rows_gen: rows generator 8 | :type rows_gen: generator 9 | """ 10 | self._column_names = column_names 11 | self._rows_gen = rows_gen 12 | self._rows = None 13 | 14 | def column_names(self): 15 | """Column names 16 | 17 | :return: list[str] 18 | """ 19 | return self._column_names 20 | 21 | def rows(self): 22 | """Rows 23 | 24 | Example: 25 | 26 | >>> [r for r in rows.rows()] 27 | 28 | :return: list generator 29 | """ 30 | if self._rows is None: 31 | self._rows = [] 32 | for row in self._rows_gen(): 33 | self._rows.append(row) 34 | yield row 35 | else: 36 | for row in self._rows: 37 | yield row 38 | 39 | def __str__(self): 40 | return self.__repr__() 41 | 42 | def __repr__(self): 43 | return self.to_dataframe().__repr__() 44 | 45 | def _repr_html_(self): 46 | return self.to_dataframe()._repr_html_() 47 | 48 | def to_dataframe(self): 49 | """Convert Rows to pandas.Dataframe 50 | 51 | :return: pandas.Dataframe 52 | """ 53 | # Fetch all the rows to self._rows 54 | for r in self.rows(): 55 | pass 56 | import pandas as pd 57 | return pd.DataFrame(self._rows, columns=self._column_names) 58 | 59 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sql-machine-learning/pysqlflow/ec63095effb9e109c56ee8bd608e16f3a0492887/tests/__init__.py -------------------------------------------------------------------------------- /tests/mock_servicer.py: -------------------------------------------------------------------------------- 1 | import grpc 2 | from concurrent import futures 3 | 4 | import google.protobuf.wrappers_pb2 as wrapper 5 | 6 | import sqlflow.proto.sqlflow_pb2 as pb 7 | import sqlflow.proto.sqlflow_pb2_grpc as pb_grpc 8 | 9 | class MockWorkflowServicer(pb_grpc.SQLFlowServicer): 10 | """ 11 | server implementation with workflow 12 | """ 13 | def Run(self, request, context): 14 | yield MockWorkflowServicer.job_response("sqlflow_couler_xxx") 15 | 16 | def Fetch(self, request, context): 17 | return MockWorkflowServicer.fetch_response(request, "fetch workflow logs") 18 | 19 | @staticmethod 20 | def job_response(job_id): 21 | pb_job = pb.Job() 22 | pb_job.id = job_id 23 | 24 | res = pb.Response() 25 | res.job.CopyFrom(pb_job) 26 | 27 | return res 28 | 29 | @staticmethod 30 | def fetch_response(req, log): 31 | pb_res = pb.FetchResponse() 32 | pb_res.updated_fetch_since.CopyFrom(req) 33 | pb_res.eof = True 34 | gen = MockServicer.table_response(MockServicer.get_test_table()) 35 | rows = [] 36 | for row in gen: 37 | rows.append(row) 38 | pb_res.responses.response.extend(rows) 39 | pb_res.responses.response.extend([MockServicer.eoe_response()]) 40 | return pb_res 41 | 42 | class 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sql-machine-learning/pysqlflow/ec63095effb9e109c56ee8bd608e16f3a0492887/tests/__init__.py
--------------------------------------------------------------------------------
/tests/mock_servicer.py:
--------------------------------------------------------------------------------
1 | import grpc
2 | from concurrent import futures
3 | 
4 | import google.protobuf.wrappers_pb2 as wrapper
5 | 
6 | import sqlflow.proto.sqlflow_pb2 as pb
7 | import sqlflow.proto.sqlflow_pb2_grpc as pb_grpc
8 | 
9 | class MockWorkflowServicer(pb_grpc.SQLFlowServicer):
10 |     """
11 |     server implementation with workflow
12 |     """
13 |     def Run(self, request, context):
14 |         yield MockWorkflowServicer.job_response("sqlflow_couler_xxx")
15 | 
16 |     def Fetch(self, request, context):
17 |         return MockWorkflowServicer.fetch_response(request, "fetch workflow logs")
18 | 
19 |     @staticmethod
20 |     def job_response(job_id):
21 |         pb_job = pb.Job()
22 |         pb_job.id = job_id
23 | 
24 |         res = pb.Response()
25 |         res.job.CopyFrom(pb_job)
26 | 
27 |         return res
28 | 
29 |     @staticmethod
30 |     def fetch_response(req, log):
31 |         pb_res = pb.FetchResponse()
32 |         pb_res.updated_fetch_since.CopyFrom(req)
33 |         pb_res.eof = True
34 |         gen = MockServicer.table_response(MockServicer.get_test_table())
35 |         rows = []
36 |         for row in gen:
37 |             rows.append(row)
38 |         pb_res.responses.response.extend(rows)
39 |         pb_res.responses.response.extend([MockServicer.eoe_response()])
40 |         return pb_res
41 | 
42 | class MockServicer(pb_grpc.SQLFlowServicer):
43 |     """
44 |     server implementation
45 |     """
46 |     def Run(self, request, context):
47 |         SQL = request.sql.upper()
48 |         if "SELECT" in SQL:
49 |             if "TRAIN" in SQL or "PREDICT" in SQL:
50 |                 for i in range(3):
51 |                     yield MockServicer.message_response("extended sql")
52 |             else:
53 |                 for res in MockServicer.table_response(MockServicer.get_test_table()):
54 |                     yield res
55 |         elif SQL == "TEST VERIFY SESSION":
56 |             # TODO(Yancey1989): use a more elegant way to test the session instead of this trick.
57 |             yield MockServicer.message_response("|".join([request.session.token, request.session.db_conn_str, str(request.session.exit_on_submit), request.session.user_id]))
58 |         elif SQL == "TEST RENDER HTML":
59 |             yield MockServicer.message_response("")
60 |         else:
61 |             yield MockServicer.message_response('bad request')
62 | 
63 |     def Fetch(self, request, context):
64 |         # delegate to the workflow helper, which builds a complete FetchResponse
65 |         return MockWorkflowServicer.fetch_response(request, "execute workflow")
66 | 
67 |     @staticmethod
68 |     def get_test_table():
69 |         return {"column_names": ['x', 'y'], "rows": [[1, 2], [3, 4]]}
70 | 
71 |     @staticmethod
72 |     def wrap_value(value):
73 |         if isinstance(value, bool):
74 |             message = wrapper.BoolValue()
75 |             message.value = value
76 |         elif isinstance(value, int):
77 |             message = wrapper.Int64Value()
78 |             message.value = value
79 |         elif isinstance(value, float):
80 |             message = wrapper.DoubleValue()
81 |             message.value = value
82 |         else:
83 |             raise Exception("Unsupported type {}".format(type(value)))
84 |         return message
85 | 
86 |     @staticmethod
87 |     def table_response(table):
88 |         res = pb.Response()
89 |         head = pb.Head()
90 |         for name in table['column_names']:
91 |             head.column_names.append(name)
92 |         res.head.CopyFrom(head)
93 |         yield res
94 | 
95 |         for row in table['rows']:
96 |             res = pb.Response()
97 |             row_message = pb.Row()
98 |             for data in row:
99 |                 row_message.data.add().Pack(MockServicer.wrap_value(data))
100 |             res.row.CopyFrom(row_message)
101 |             yield res
102 | 
103 |     @staticmethod
104 |     def message_response(message):
105 |         pb_msg = pb.Message()
106 |         pb_msg.message = message
107 | 
108 |         res = pb.Response()
109 |         res.message.CopyFrom(pb_msg)
110 |         return res
111 | 
112 |     @staticmethod
113 |     def eoe_response():
114 |         pb_eoe = pb.EndOfExecution()
115 |         res = pb.Response()
116 |         res.eoe.CopyFrom(pb_eoe)
117 |         return res
118 | 
119 | 
120 | def _server(server_instance, port, event, ca_crt, ca_key):
121 |     svr = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
122 |     with open(ca_key, "rb") as f:
123 |         private_key = f.read()
124 |     with open(ca_crt, "rb") as f:
125 |         certification_chain = f.read()
126 |     server_credentials = grpc.ssl_server_credentials(((private_key, certification_chain),))
127 |     pb_grpc.add_SQLFlowServicer_to_server(server_instance, svr)
128 |     svr.add_secure_port('[::]:%d' % port, server_credentials)
129 |     svr.start()
130 |     try:
131 |         event.wait()
132 |     except KeyboardInterrupt:
133 |         pass
134 |     svr.stop(0)
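
For reference, a sketch of the value round-trip the mock relies on: each table cell is wrapped in a protobuf well-known wrapper type and packed into a google.protobuf.Any, and the receiving side unpacks it (the real decoding lives in sqlflow.client):

from google.protobuf.any_pb2 import Any
import google.protobuf.wrappers_pb2 as wrapper

any_message = Any()
wrapped = wrapper.Int64Value()
wrapped.value = 42
any_message.Pack(wrapped)       # what wrap_value + Pack does on the server side

out = wrapper.Int64Value()
assert any_message.Unpack(out)  # Unpack returns True when the payload type matches
assert out.value == 42
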
") 60 | else: 61 | yield MockServicer.message_response('bad request', 0) 62 | 63 | def Fetch(self, request, context): 64 | job_id = request.job.id 65 | yield MockServicer.message_fetch_response(request.job, "execute workflow") 66 | 67 | @staticmethod 68 | def get_test_table(): 69 | return {"column_names": ['x', 'y'], "rows": [[1, 2], [3, 4]]} 70 | 71 | @staticmethod 72 | def wrap_value(value): 73 | if isinstance(value, bool): 74 | message = wrapper.BoolValue() 75 | message.value = value 76 | elif isinstance(value, int): 77 | message = wrapper.Int64Value() 78 | message.value = value 79 | elif isinstance(value, float): 80 | message = wrapper.DoubleValue() 81 | message.value = value 82 | else: 83 | raise Exception("Unsupported type {}".format(type(value))) 84 | return message 85 | 86 | @staticmethod 87 | def table_response(table): 88 | res = pb.Response() 89 | head = pb.Head() 90 | for name in table['column_names']: 91 | head.column_names.append(name) 92 | res.head.CopyFrom(head) 93 | yield res 94 | 95 | for row in table['rows']: 96 | res = pb.Response() 97 | row_message = pb.Row() 98 | for data in row: 99 | row_message.data.add().Pack(MockServicer.wrap_value(data)) 100 | res.row.CopyFrom(row_message) 101 | yield res 102 | 103 | @staticmethod 104 | def message_response(message): 105 | pb_msg = pb.Message() 106 | pb_msg.message = message 107 | 108 | res = pb.Response() 109 | res.message.CopyFrom(pb_msg) 110 | return res 111 | 112 | @staticmethod 113 | def eoe_response(): 114 | pb_eoe = pb.EndOfExecution() 115 | res = pb.Response() 116 | res.eoe.CopyFrom(pb_eoe) 117 | return res 118 | 119 | 120 | def _server(server_instance, port, event, ca_crt, ca_key): 121 | svr = grpc.server(futures.ThreadPoolExecutor(max_workers=2)) 122 | with open(ca_key, "rb") as f: 123 | private_key = f.read() 124 | with open(ca_crt, "rb") as f: 125 | certification_chain = f.read() 126 | server_credentials = grpc.ssl_server_credentials( ( (private_key, certification_chain), ) ) 127 | pb_grpc.add_SQLFlowServicer_to_server(server_instance, svr) 128 | svr.add_secure_port('[::]:%d' % port, server_credentials) 129 | svr.start() 130 | try: 131 | event.wait() 132 | except KeyboardInterrupt: 133 | pass 134 | svr.stop(0) 135 | -------------------------------------------------------------------------------- /tests/test_client.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import threading 3 | import time 4 | import os 5 | from unittest import mock 6 | import tempfile 7 | import subprocess 8 | import shutil 9 | 10 | from sqlflow.client import Client 11 | from tests.mock_servicer import _server, MockServicer, MockWorkflowServicer 12 | 13 | from google.protobuf.timestamp_pb2 import Timestamp 14 | from google.protobuf.any_pb2 import Any 15 | import sqlflow.proto.sqlflow_pb2 as pb 16 | 17 | def generateTempCA(): 18 | tmp_dir = tempfile.mkdtemp(suffix="sqlflow_ssl", dir="/tmp") 19 | ca_key = os.path.join(tmp_dir, "ca.key") 20 | ca_csr = os.path.join(tmp_dir, "ca.csr") 21 | ca_crt = os.path.join(tmp_dir, "ca.crt") 22 | 23 | assert subprocess.call(["openssl", "genrsa", "-out", ca_key, "2048"]) == 0 24 | assert subprocess.call(["openssl", "req", "-nodes", "-new", "-key", ca_key, "-subj", "/CN=localhost", "-out", ca_csr]) == 0 25 | assert subprocess.call(["openssl", "x509", "-req", "-sha256", "-days", "365", "-in", ca_csr, "-signkey", ca_key, "-out", ca_crt]) == 0 26 | 27 | return tmp_dir, ca_crt, ca_key 28 | 29 | 30 | class ClientServerTest(unittest.TestCase): 31 | @classmethod 
32 |     def setUpClass(cls):
33 |         # TODO: picking a free port dynamically would be better
34 |         port = 8765
35 |         cls.server_url = "localhost:%d" % port
36 |         cls.event = threading.Event()
37 |         cls.tmp_ca_dir, cls.ca_crt, ca_key = generateTempCA()
38 |         threading.Thread(target=_server, args=[MockServicer(), port, cls.event, cls.ca_crt, ca_key]).start()
39 |         # wait for the server to start
40 |         time.sleep(1)
41 |         cls.client = Client(cls.server_url, cls.ca_crt)
42 | 
43 |     @classmethod
44 |     def tearDownClass(cls):
45 |         # shut down the server after this test
46 |         cls.event.set()
47 |         shutil.rmtree(cls.tmp_ca_dir, ignore_errors=True)
48 | 
49 |     def test_execute_stream(self):
50 |         with mock.patch('sqlflow.client._LOGGER') as log_mock:
51 |             self.client.execute("select * from galaxy train ..")
52 |             log_mock.info.assert_called_with("extended sql")
53 | 
54 |         expected_table = MockServicer.get_test_table()
55 |         rows = self.client.execute("select * from galaxy").get(0)
56 |         assert expected_table["column_names"] == rows.column_names()
57 |         assert expected_table["rows"] == [r for r in rows.rows()]
58 | 
59 |     def test_cmd(self):
60 |         assert subprocess.call(["sqlflow", "--url", self.server_url,
61 |                                 "--ca_crt", self.ca_crt,
62 |                                 "select * from galaxy"]) == 0
63 | 
64 |     def test_decode_time(self):
65 |         any_message = Any()
66 |         timestamp_message = Timestamp()
67 |         timestamp_message.GetCurrentTime()
68 |         any_message.Pack(timestamp_message)
69 |         assert timestamp_message.ToDatetime() == Client._decode_any(any_message)
70 | 
71 |     def test_decode_null(self):
72 |         any_message = Any()
73 |         null_message = pb.Row.Null()
74 |         any_message.Pack(null_message)
75 |         assert Client._decode_any(any_message) is None
76 | 
77 |     # TODO(typhoonzero): this test does not seem useful; find a better way to verify the session
78 |     # def test_session(self):
79 |     #     token = "unittest-user"
80 |     #     ds = "maxcompute://AK:SK@host:port"
81 |     #     os.environ["SQLFLOW_USER_TOKEN"] = token
82 |     #     os.environ["SQLFLOW_DATASOURCE"] = ds
83 |     #     os.environ["SQLFLOW_EXIT_ON_SUBMIT"] = "True"
84 |     #     os.environ["SQLFLOW_USER_ID"] = "sqlflow_user"
85 |     #     os.environ["SQLFLOW_SUBMITTER"] = "pai"
86 |     #     with mock.patch('sqlflow.client._LOGGER') as log_mock:
87 |     #         self.client.execute("TEST VERIFY SESSION")
88 |     #         log_mock.debug.assert_called_with("|".join([token, ds, "True", "sqlflow_user"]))
89 | 
90 |     def test_draw_html(self):
91 |         from IPython.core.display import display, HTML
92 |         with mock.patch('IPython.core.display.HTML') as log_mock:
93 |             self.client.execute("TEST RENDER HTML")
94 |             log_mock.assert_called_with("")
") 95 | 96 | class ClientWorkflowServerTest(unittest.TestCase): 97 | @classmethod 98 | def setUpClass(cls): 99 | # TODO: free port is better 100 | port = 8766 101 | cls.server_url = "localhost:%d" % port 102 | cls.event = threading.Event() 103 | cls.tmp_ca_dir, cls.ca_crt, ca_key = generateTempCA() 104 | threading.Thread(target=_server, args=[MockWorkflowServicer(), port, cls.event, cls.ca_crt, ca_key]).start() 105 | # wait for start 106 | time.sleep(1) 107 | cls.client = Client(cls.server_url, cls.ca_crt) 108 | 109 | @classmethod 110 | def tearDownClass(cls): 111 | # shutdown server after this test 112 | cls.event.set() 113 | shutil.rmtree(cls.tmp_ca_dir, ignore_errors=True) 114 | 115 | def test_execute_stream(self): 116 | expected_table = MockServicer.get_test_table() 117 | rows = self.client.execute("select * from galaxy train ..").get(0) 118 | assert expected_table["column_names"] == rows.column_names() 119 | assert expected_table["rows"] == [r for r in rows.rows()] 120 | -------------------------------------------------------------------------------- /tests/test_env_expand.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from sqlflow.env_expand import EnvExpander, EnvExpanderError 4 | 5 | class EnvExpanderTest(unittest.TestCase): 6 | @classmethod 7 | def setUpClass(cls): 8 | os.environ["BIZDATE"] = "20190731" 9 | os.environ["t1"] = "tablename" 10 | cls.expander = EnvExpander(os.environ) 11 | 12 | def test_expand(self): 13 | sql = "SELECT * from ${t1} where pt=${yyyyMMdd}" 14 | new_sql = self.expander.expand(sql) 15 | expected_sql = "SELECT * from tablename where pt=20190731" 16 | assert new_sql == expected_sql 17 | 18 | def test_bizdate_delta(self): 19 | sql = "SELECT * from ${t1} where pt=${yyyyMMdd + 1d}" 20 | new_sql = self.expander.expand(sql) 21 | expected_sql = "SELECT * from tablename where pt=20190801" 22 | assert new_sql == expected_sql 23 | 24 | def test_bizdate_arbitrary_space(self): 25 | sql = "SELECT * from ${t1} where pt=${yyyyMMdd+ 1d}" 26 | new_sql = self.expander.expand(sql) 27 | expected_sql = "SELECT * from tablename where pt=20190801" 28 | assert new_sql == expected_sql 29 | 30 | def test_expand_error(self): 31 | sql = "SELECT * from ${no_exists} where pt=${yyyyMMdd + 1d}" 32 | with self.assertRaises(EnvExpanderError): 33 | self.expander.expand(sql) 34 | 35 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | import sqlflow 2 | 3 | 4 | def test_answer(): 5 | assert sqlflow.__version__ == sqlflow._version.__version__ 6 | --------------------------------------------------------------------------------