├── .coveragerc ├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── gsheetsdb ├── __init__.py ├── __version__.py ├── auth.py ├── console.py ├── convert.py ├── db.py ├── dialect.py ├── exceptions.py ├── formatting.py ├── processors.py ├── query.py ├── sqlite.py ├── translator.py ├── types.py ├── url.py └── utils.py ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── context.py ├── test_console.py ├── test_convert.py ├── test_db.py ├── test_dialect.py ├── test_processing.py ├── test_query.py ├── test_translation.py ├── test_url.py └── test_utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | # Have to re-enable the standard pragma 4 | pragma: no cover 5 | 6 | # Don't complain if tests don't hit defensive assertion code: 7 | raise NotImplementedError 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - 3.4 5 | - 3.5 6 | - 3.6 7 | matrix: 8 | include: 9 | - python: 3.7 10 | dist: xenial 11 | sudo: true 12 | 13 | install: 14 | - pip install codecov 15 | - pip install -e .[cli,dev,sqlalchemy] 16 | 17 | script: 18 | - flake8 19 | - py.test --cov=gsheetsdb/ 20 | 21 | after_success: 22 | - codecov 23 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [0.1.0] - 2018-09-16 4 | - Initial release providing a Python DB API SQL interface to Google spreadsheets and a CLI. 
5 | 6 | ## [0.1.1] - 2018-09-16 7 | - Added SQLAlchemy dialect. 8 | - Allow headers and gid to be passed on the URL. 9 | 10 | ## [0.1.2] - 2018-09-16 11 | - Add missing dependency to `moz-sql-parser` to `setup.py`. 12 | 13 | ## [0.1.3] - 2018-09-16 14 | - Fix small bug in SQL Alchemy compiler. 15 | - Allow aliases in `ORDER BY`. 16 | 17 | ## [0.1.4] - 2018-09-16 18 | - Fix `visit_column` method. 19 | 20 | ## [0.1.5] - 2018-10-25 21 | - Parse dates, better error message. 22 | - `COUNT(*)` working. 23 | - Fallback to SQLite if query fails. 24 | - Custom date truncation using `DATETRUNC`. 25 | 26 | ## [0.1.6] - 2018-10-30 27 | - Handle authentication. 28 | - Fix cursor description in SQLite fallback. 29 | 30 | ## [0.1.7] - 2018-12-06 31 | - Fix session when no credentials are passed. 32 | 33 | ## [0.1.8] - 2018-12-07 34 | - Add logging. 35 | - Fix `CREATE TABLE` when sheet has no headers. 36 | 37 | ## [0.1.9] - 2019-01-07 38 | - Ensure URL is from docs.google.com. 39 | 40 | ## [0.1.10] - 2020-05-06 41 | - Strip blank columns caused by edits. 
42 | 43 | ## [0.1.11] - 2020-09-21 44 | - Handle subqueries by falling back to Sqlite 45 | 46 | ## [0.1.12] - 2020-12-04 47 | - Implement `do_ping` method always returning true 48 | 49 | ## [0.1.13] - 2021-02-17 50 | - Handle escaped periods in URL 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Beto Dealmeida 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | init: requirements 2 | pip install -r requirements.txt 3 | 4 | test: 5 | nosetests tests 6 | 7 | requirements: 8 | pipreqs --force druiddb --savepath requirements.txt 9 | 10 | .PHONY: init test requirements 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/betodealmeida/gsheets-db-api.svg?branch=master)](https://travis-ci.org/betodealmeida/gsheets-db-api) [![codecov](https://codecov.io/gh/betodealmeida/gsheets-db-api/branch/master/graph/badge.svg)](https://codecov.io/gh/betodealmeida/gsheets-db-api) 2 | 3 | **Note:** [shillelagh](https://github.com/betodealmeida/shillelagh/) is a drop-in replacement for `gsheets-db-api`, with many additional features. You should use it instead. If you're using SQLAlchemy all you need to do: 4 | 5 | ```bash 6 | $ pip uninstall gsheetsdb 7 | $ pip install shillelagh 8 | ``` 9 | 10 | If you're using the DB API: 11 | 12 | ```bash 13 | # from gsheetsdb import connect 14 | from shillelagh.backends.apsw.db import connect 15 | ``` 16 | 17 | # A Python DB API 2.0 for Google Spreadsheets # 18 | 19 | This module allows you to query Google Spreadsheets using SQL. 
20 | 21 | Using [this spreadsheet](https://docs.google.com/spreadsheets/d/1_rN3lm0R_bU3NemO0s9pbFkY5LQPcuy1pscv8ZXPtg8/) as an example: 22 | 23 | | | A | B | 24 | |-|--------|-----| 25 | | 1 | country | cnt | 26 | | 2 | BR | 1 | 27 | | 3 | BR | 3 | 28 | | 4 | IN | 5 | 29 | 30 | Here's a simple query using the Python API: 31 | 32 | ```python 33 | from gsheetsdb import connect 34 | 35 | conn = connect() 36 | result = conn.execute(""" 37 | SELECT 38 | country 39 | , SUM(cnt) 40 | FROM 41 | "https://docs.google.com/spreadsheets/d/1_rN3lm0R_bU3NemO0s9pbFkY5LQPcuy1pscv8ZXPtg8/" 42 | GROUP BY 43 | country 44 | """, headers=1) 45 | for row in result: 46 | print(row) 47 | ``` 48 | 49 | This will print: 50 | 51 | ``` 52 | Row(country='BR', sum_cnt=4.0) 53 | Row(country='IN', sum_cnt=5.0) 54 | ``` 55 | 56 | ## How it works ## 57 | 58 | ### Transpiling ### 59 | 60 | Google spreadsheets can actually be queried with a [very limited SQL API](https://developers.google.com/chart/interactive/docs/querylanguage). This module will transpile the SQL query into a simpler query that the API understands. Eg, the query above would be translated to: 61 | 62 | ```sql 63 | SELECT A, SUM(B) GROUP BY A 64 | ``` 65 | 66 | ### Processors ### 67 | 68 | In addition to transpiling, this module also provides pre- and post-processors. The pre-processors add more columns to the query, and the post-processors build the actual result from those extra columns. Eg, `COUNT(*)` is not supported, so the following query: 69 | 70 | ```sql 71 | SELECT COUNT(*) FROM "https://docs.google.com/spreadsheets/d/1_rN3lm0R_bU3NemO0s9pbFkY5LQPcuy1pscv8ZXPtg8/" 72 | ``` 73 | 74 | Gets translated to: 75 | 76 | ```sql 77 | SELECT COUNT(A), COUNT(B) 78 | ``` 79 | 80 | And then the maximum count is returned. This assumes that at least one column has no `NULL`s. 
81 | 82 | 83 | ### SQLite ### 84 | When a query can't be expressed, the module will issue a `SELECT *`, load the data into an in-memory SQLite table, and execute the query in SQLite. This is obviously inneficient, since all data has to be downloaded, but ensures that all queries succeed. 85 | 86 | ## Installation ## 87 | 88 | ```bash 89 | $ pip install gsheetsdb 90 | $ pip install gsheetsdb[cli] # if you want to use the CLI 91 | $ pip install gsheetsdb[sqlalchemy] # if you want to use it with SQLAlchemy 92 | ``` 93 | 94 | ## CLI ## 95 | 96 | The module will install an executable called `gsheetsdb`: 97 | 98 | ```bash 99 | $ gsheetsdb --headers=1 100 | > SELECT * FROM "https://docs.google.com/spreadsheets/d/1_rN3lm0R_bU3NemO0s9pbFkY5LQPcuy1pscv8ZXPtg8/" 101 | country cnt 102 | --------- ----- 103 | BR 1 104 | BR 3 105 | IN 5 106 | > SELECT country, SUM(cnt) FROM "https://docs.google.com/spreadsheets/d/1_rN3lm0R_bU3NemO0s9pbFkY5LQPcuy1 107 | pscv8ZXPtg8/" GROUP BY country 108 | country sum cnt 109 | --------- --------- 110 | BR 4 111 | IN 5 112 | > 113 | ``` 114 | 115 | ## SQLAlchemy support ## 116 | 117 | This module provides a SQLAlchemy dialect. You don't need to specify a URL, since the spreadsheet is extracted from the `FROM` clause: 118 | 119 | ```python 120 | from sqlalchemy import * 121 | from sqlalchemy.engine import create_engine 122 | from sqlalchemy.schema import * 123 | 124 | engine = create_engine('gsheets://') 125 | inspector = inspect(engine) 126 | 127 | table = Table( 128 | 'https://docs.google.com/spreadsheets/d/1_rN3lm0R_bU3NemO0s9pbFkY5LQPcuy1pscv8ZXPtg8/edit#gid=0', 129 | MetaData(bind=engine), 130 | autoload=True) 131 | query = select([func.count(table.columns.country)], from_obj=table) 132 | print(query.scalar()) # prints 3.0 133 | ``` 134 | 135 | Alternatively, you can initialize the engine with a "catalog". 
The catalog is a Google spreadsheet where each row points to another Google spreadsheet, with URL, number of headers and schema as the columns. You can see an example [here](https://docs.google.com/spreadsheets/d/1AAqVVSpGeyRZyrr4n--fb_IxhLwwKtLbjfu4h6MyyYA/edit#gid=0): 136 | 137 | || A | B | C | 138 | |-|-|-|-| 139 | | 1 | https://docs.google.com/spreadsheets/d/1_rN3lm0R_bU3NemO0s9pbFkY5LQPcuy1pscv8ZXPtg8/edit#gid=0 | 1 | default | 140 | | 2 | https://docs.google.com/spreadsheets/d/1_rN3lm0R_bU3NemO0s9pbFkY5LQPcuy1pscv8ZXPtg8/edit#gid=1077884006 | 2 | default | 141 | 142 | This will make the two spreadsheets above available as "tables" in the `default` schema. 143 | 144 | 145 | ## Authentication ## 146 | 147 | You can access spreadsheets that are shared only within an organization. In order to do this, first [create a service account](https://developers.google.com/api-client-library/python/auth/service-accounts#creatinganaccount). Make sure you select "Enable G Suite Domain-wide Delegation". Download the key as a JSON file. 148 | 149 | Next, you need to manage API client access at https://admin.google.com/${DOMAIN}/AdminHome?chromeless=1#OGX:ManageOauthClients. Add the "Unique ID" from the previous step as the "Client Name", and add `https://spreadsheets.google.com/feeds` as the scope. 
150 | 151 | Now, when creating the connection from the DB API or from SQLAlchemy you can point to the JSON file and the user you want to impersonate: 152 | 153 | ```python 154 | >>> auth = {'service_account_file': '/path/to/certificate.json', 'subject': 'user@domain.com'} 155 | >>> conn = connect(auth) 156 | ``` 157 | 158 | -------------------------------------------------------------------------------- /gsheetsdb/__init__.py: -------------------------------------------------------------------------------- 1 | from gsheetsdb.db import connect 2 | from gsheetsdb.exceptions import ( 3 | DataError, 4 | DatabaseError, 5 | Error, 6 | IntegrityError, 7 | InterfaceError, 8 | InternalError, 9 | NotSupportedError, 10 | OperationalError, 11 | ProgrammingError, 12 | Warning, 13 | ) 14 | 15 | 16 | __all__ = [ 17 | 'connect', 18 | 'apilevel', 19 | 'threadsafety', 20 | 'paramstyle', 21 | 'DataError', 22 | 'DatabaseError', 23 | 'Error', 24 | 'IntegrityError', 25 | 'InterfaceError', 26 | 'InternalError', 27 | 'NotSupportedError', 28 | 'OperationalError', 29 | 'ProgrammingError', 30 | 'Warning', 31 | ] 32 | 33 | 34 | apilevel = '2.0' 35 | # Threads may share the module and connections 36 | threadsafety = 2 37 | paramstyle = 'pyformat' 38 | -------------------------------------------------------------------------------- /gsheetsdb/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.13" 2 | -------------------------------------------------------------------------------- /gsheetsdb/auth.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import json 7 | 8 | from google.oauth2 import service_account 9 | 10 | 11 | # Google API scopes for authentication 12 | # 
def get_credentials_from_auth(
    service_account_file=None,
    service_account_info=None,
    subject=None,
):
    """
    Build Google service-account credentials for the spreadsheet API.

    A key *file*, when given, takes precedence: its JSON content replaces
    any inline ``service_account_info`` dict. Returns ``None`` when no
    account information is available (anonymous access).
    """
    if service_account_file:
        with open(service_account_file) as key_file:
            service_account_info = json.load(key_file)

    if service_account_info:
        return service_account.Credentials.from_service_account_info(
            service_account_info, scopes=SCOPES, subject=subject)

    return None
11 | --headers= How many rows are headers [default: 0] 12 | --service-account-file= Service account file for authentication 13 | --subject= Subject to impersonate 14 | 15 | """ # noqa: E501 16 | 17 | from __future__ import unicode_literals 18 | 19 | import os 20 | 21 | from docopt import docopt 22 | from prompt_toolkit import prompt 23 | from prompt_toolkit.history import FileHistory 24 | from prompt_toolkit.completion import WordCompleter 25 | from prompt_toolkit.lexers import PygmentsLexer 26 | from prompt_toolkit.styles.pygments import style_from_pygments_cls 27 | from pygments.lexers import SqlLexer 28 | from pygments.styles import get_style_by_name 29 | from tabulate import tabulate 30 | 31 | from gsheetsdb import connect, __version__ 32 | from gsheetsdb.auth import get_credentials_from_auth 33 | 34 | 35 | keywords = [ 36 | 'and', 37 | 'asc', 38 | 'by', 39 | 'date', 40 | 'datetime', 41 | 'desc', 42 | 'false', 43 | 'format', 44 | 'group', 45 | 'label', 46 | 'limit', 47 | 'not', 48 | 'offset', 49 | 'options', 50 | 'or', 51 | 'order', 52 | 'pivot', 53 | 'select', 54 | 'timeofday', 55 | 'timestamp', 56 | 'true', 57 | 'where', 58 | ] 59 | 60 | aggregate_functions = [ 61 | 'avg', 62 | 'count', 63 | 'max', 64 | 'min', 65 | 'sum', 66 | ] 67 | 68 | scalar_functions = [ 69 | 'year', 70 | 'month', 71 | 'day', 72 | 'hour', 73 | 'minute', 74 | 'second', 75 | 'millisecond', 76 | 'quarter', 77 | 'dayOfWeek', 78 | 'now', 79 | 'dateDiff', 80 | 'toDate', 81 | 'upper', 82 | 'lower', 83 | ] 84 | 85 | 86 | def main(): 87 | history = FileHistory(os.path.expanduser('~/.gsheetsdb_history')) 88 | 89 | arguments = docopt(__doc__, version=__version__.__version__) 90 | 91 | auth = { 92 | 'service_account_file': arguments['--service-account-file'], 93 | 'subject': arguments['--subject'], 94 | } 95 | credentials = get_credentials_from_auth(**auth) 96 | connection = connect(credentials) 97 | headers = int(arguments['--headers']) 98 | cursor = connection.cursor() 99 | 100 | lexer = 
def parse_datetime(v):
    """Convert a gviz 'Date(2018,0,1,0,0,0)' literal into a `datetime`."""
    numbers = [int(part) for part in v[len('Date('):-1].split(',')]
    numbers[1] += 1  # gviz months are zero-based
    return datetime.datetime(*numbers)


def parse_date(v):
    """Convert a gviz 'Date(2018,0,1)' literal into a `date`."""
    numbers = [int(part) for part in v[len('Date('):-1].split(',')]
    numbers[1] += 1  # gviz months are zero-based
    return datetime.date(*numbers)


def parse_timeofday(v):
    """Convert a gviz time-of-day sequence (hour, minute, ...) into a `time`."""
    return datetime.time(*v)


# Map gviz column types to converters applied to each cell value.
converters = {
    'string': lambda v: v,
    'number': lambda v: v,
    'boolean': lambda v: v,
    'date': parse_date,
    'datetime': parse_datetime,
    'timeofday': parse_timeofday,
}


def convert_rows(cols, rows):
    """
    Build a list of namedtuples from gviz `cols`/`rows` payloads.

    Column labels become field names (spaces replaced by underscores;
    invalid names are auto-renamed). Cells beyond the declared columns
    are dropped, and empty cells become ``None``.
    """
    labels = [col['label'].replace(' ', '_') for col in cols]
    Row = namedtuple('Row', labels, rename=True)

    results = []
    for row in rows:
        cells = row['c'][:len(cols)]
        values = []
        for position, cell in enumerate(cells):
            convert = converters[cols[position]['type']]
            values.append(convert(cell['v']) if cell else None)
        results.append(Row(*values))

    return results
def check_closed(f):
    """Decorator that refuses to act on a closed connection/cursor."""

    def wrapper(self, *args, **kwargs):
        if self.closed:
            raise Error(
                '{klass} already closed'.format(
                    klass=self.__class__.__name__))
        return f(self, *args, **kwargs)

    return wrapper


def check_result(f):
    """Decorator that requires a successful `execute` before fetching."""

    def wrapper(self, *args, **kwargs):
        if self._results is None:
            raise Error('Called before `execute`')
        return f(self, *args, **kwargs)

    return wrapper


class Connection(object):

    """Connection to a Google Spreadsheet."""

    def __init__(self, credentials=None):
        # Credentials are shared with every cursor spawned from here.
        self.credentials = credentials

        self.closed = False
        self.cursors = []

    @check_closed
    def close(self):
        """Close the connection and every cursor created from it."""
        self.closed = True
        for cursor in self.cursors:
            try:
                cursor.close()
            except Error:
                pass  # already closed

    @check_closed
    def commit(self):
        """No-op: transactions are not supported by the backend."""

    @check_closed
    def cursor(self):
        """Create and register a new cursor bound to this connection."""
        cursor = Cursor(self.credentials)
        self.cursors.append(cursor)
        return cursor

    @check_closed
    def execute(self, operation, parameters=None, headers=0):
        """Convenience wrapper: run `operation` on a fresh cursor."""
        return self.cursor().execute(operation, parameters, headers)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.commit()  # no-op
        self.close()


class Cursor(object):

    """Connection cursor."""

    def __init__(self, credentials=None):
        self.credentials = credentials

        # Number of rows returned by `fetchmany` when no size is given;
        # defaults to fetching a single row at a time (DB API 2.0).
        self.arraysize = 1

        self.closed = False

        # Updated only after a successful query.
        self.description = None

        # Holds the remaining (unfetched) rows after a successful query.
        self._results = None

    @property
    @check_result
    @check_closed
    def rowcount(self):
        """Number of rows still buffered from the last query."""
        return len(self._results)

    @check_closed
    def close(self):
        """Close the cursor."""
        self.closed = True

    @check_closed
    def execute(self, operation, parameters=None, headers=0):
        """Run a query; fall back to SQLite when the API can't express it."""
        self.description = None
        query = apply_parameters(operation, parameters or {})
        try:
            self._results, self.description = execute(
                query, headers, self.credentials)
        except (ProgrammingError, NotSupportedError):
            logger.info('Query failed, running in SQLite')
            self._results, self.description = sqlite_execute(
                query, headers, self.credentials)
        return self

    @check_closed
    def executemany(self, operation, seq_of_parameters=None):
        raise NotSupportedError(
            '`executemany` is not supported, use `execute` instead')

    @check_result
    @check_closed
    def fetchone(self):
        """
        Fetch the next row of a query result set, returning a single
        sequence, or `None` when no more data is available.
        """
        if self._results:
            return self._results.pop(0)
        return None

    @check_result
    @check_closed
    def fetchmany(self, size=None):
        """
        Fetch the next set of rows (at most `size`, defaulting to
        `arraysize`), returning a possibly-empty list of rows.
        """
        size = size or self.arraysize
        out, self._results = self._results[:size], self._results[size:]
        return out

    @check_result
    @check_closed
    def fetchall(self):
        """Fetch all remaining rows of the query result as a list."""
        out, self._results = self._results, []
        return out

    @check_closed
    def setinputsizes(self, sizes):
        """Not supported."""

    @check_closed
    def setoutputsizes(self, sizes):
        """Not supported."""

    @check_closed
    def __iter__(self):
        return iter(self._results)


def apply_parameters(operation, parameters):
    """Interpolate escaped `parameters` into a pyformat `operation`."""
    return operation % {
        key: escape(value) for key, value in parameters.items()
    }


def escape(value):
    """Quote a Python value for safe inclusion in a query string."""
    if value == '*':
        return value
    elif isinstance(value, string_types):
        # double single-quotes to escape them
        return "'{}'".format(value.replace("'", "''"))
    elif isinstance(value, bool):  # must precede the int check
        return 'TRUE' if value else 'FALSE'
    elif isinstance(value, (int, float)):
        return str(value)
    elif isinstance(value, (list, tuple)):
        return '({0})'.format(', '.join(escape(element) for element in value))
class GSheetsIdentifierPreparer(compiler.IdentifierPreparer):
    # Identifiers matching these must be quoted in generated queries.
    # https://developers.google.com/chart/interactive/docs/querylanguage#reserved-words
    reserved_words = {
        "and",
        "asc",
        "by",
        "date",
        "datetime",
        "desc",
        "false",
        "format",
        "group",
        "label",
        "limit",
        "not",
        "offset",
        "options",
        "or",
        "order",
        "pivot",
        "select",
        "timeofday",
        "timestamp",
        "true",
        "where",
    }


class GSheetsCompiler(compiler.SQLCompiler):
    """SQL compiler tweaks for the gviz query language."""

    def visit_column(self, column, **kwargs):
        # Disable table-qualified column names: the "table" is a URL, so
        # columns must be emitted bare.
        if column.table is not None:
            column.table.named_with_column = False
        return super(GSheetsCompiler, self).visit_column(column, **kwargs)

    def visit_table(
        self,
        table,
        asfrom=False,
        iscrud=False,
        ashint=False,
        fromhints=None,
        use_schema=False,
        **kwargs
    ):
        # Force use_schema=False (last positional arg): spreadsheet URLs
        # have no schema to prefix.
        return super(GSheetsCompiler, self).visit_table(
            table, asfrom, iscrud, ashint, fromhints, False, **kwargs
        )


class GSheetsTypeCompiler(compiler.GenericTypeCompiler):
    # No type rendering customizations needed.
    pass


class GSheetsDialect(default.DefaultDialect):

    """SQLAlchemy dialect for querying Google Spreadsheets over HTTPS."""

    # TODO: review these
    # http://docs.sqlalchemy.org/en/latest/core/internals.html#sqlalchemy.engine.interfaces.Dialect
    name = "gsheets"
    scheme = "https"
    driver = "rest"
    preparer = GSheetsIdentifierPreparer
    statement_compiler = GSheetsCompiler
    type_compiler = GSheetsTypeCompiler
    supports_alter = False
    supports_pk_autoincrement = False
    supports_default_values = False
    supports_empty_insert = False
    supports_unicode_statements = True
    supports_unicode_binds = True
    returns_unicode_strings = True
    description_encoding = None
    supports_native_boolean = True

    def __init__(
        self,
        service_account_file=None,
        service_account_info=None,
        subject=None,
        *args,
        **kwargs
    ):
        super(GSheetsDialect, self).__init__(*args, **kwargs)
        # Optional service-account auth; None means anonymous access.
        self.credentials = get_credentials_from_auth(
            service_account_file, service_account_info, subject
        )

    @classmethod
    def dbapi(cls):
        return gsheetsdb

    def do_ping(self, dbapi_connection):
        # There is no server connection to ping; always report healthy.
        return True

    def create_connect_args(self, url):
        # When a host is given, the URL points to a "catalog" spreadsheet
        # listing other sheets; otherwise tables come from the FROM clause.
        port = ":{url.port}".format(url=url) if url.port else ""
        if url.host is None:
            self.url = None
        else:
            self.url = "{scheme}://{host}{port}/{database}".format(
                scheme=self.scheme,
                host=url.host,
                port=port,
                database=url.database or "",
            )
        # Positional args for gsheetsdb.connect(credentials).
        return ([self.credentials], {})

    def get_schema_names(self, connection, **kwargs):
        # Column C of the catalog sheet holds the schema name per row.
        if self.url is None:
            return []

        query = 'SELECT C, COUNT(C) FROM "{catalog}" GROUP BY C'.format(
            catalog=self.url
        )
        result = connection.execute(query)
        return [row[0] for row in result.fetchall()]

    def has_table(self, connection, table_name, schema=None):
        # Without a catalog any URL is assumed to be a valid table.
        if self.url is None:
            return True

        return table_name in self.get_table_names(connection, schema)

    def get_table_names(self, connection, schema=None, **kwargs):
        # Catalog rows are (URL, headers, schema); headers/gid are folded
        # back into the URL fragment/query string.
        if self.url is None:
            return []

        query = 'SELECT * FROM "{catalog}"'.format(catalog=self.url)
        if schema:
            query = "{query} WHERE C='{schema}'".format(query=query, schema=schema)
        result = connection.execute(query)
        return [add_headers(row[0], int(row[1])) for row in result.fetchall()]

    def get_view_names(self, connection, schema=None, **kwargs):
        return []

    def get_table_options(self, connection, table_name, schema=None, **kwargs):
        return {}

    def get_columns(self, connection, table_name, schema=None, **kwargs):
        # LIMIT 0 fetches only the cursor description (names and types).
        query = 'SELECT * FROM "{table}" LIMIT 0'.format(table=table_name)
        result = connection.execute(query)
        # NOTE(review): `_cursor_description()` is a private SQLAlchemy
        # accessor and may break across versions — confirm on upgrade.
        return [
            {
                "name": col[0],
                "type": type_map[col[1].value],
                "nullable": True,
                "default": None,
            }
            for col in result._cursor_description()
        ]

    def get_pk_constraint(self, connection, table_name, schema=None, **kwargs):
        return {"constrained_columns": [], "name": None}

    def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
        return []

    def get_check_constraints(self, connection, table_name, schema=None, **kwargs):
        return []

    def get_table_comment(self, connection, table_name, schema=None, **kwargs):
        return {"text": ""}

    def get_indexes(self, connection, table_name, schema=None, **kwargs):
        return []

    def get_unique_constraints(self, connection, table_name, schema=None, **kwargs):
        return []

    def get_view_definition(self, connection, view_name, schema=None, **kwargs):
        pass

    def do_rollback(self, dbapi_connection):
        # Transactions are not supported; rollback is a no-op.
        pass

    def _check_unicode_returns(self, connection, additional_tests=None):
        return True

    def _check_unicode_description(self, connection):
        return True
DatabaseError(Error): 20 | pass 21 | 22 | 23 | class InternalError(DatabaseError): 24 | pass 25 | 26 | 27 | class OperationalError(DatabaseError): 28 | pass 29 | 30 | 31 | class ProgrammingError(DatabaseError): 32 | pass 33 | 34 | 35 | class IntegrityError(DatabaseError): 36 | pass 37 | 38 | 39 | class DataError(DatabaseError): 40 | pass 41 | 42 | 43 | class NotSupportedError(DatabaseError): 44 | pass 45 | -------------------------------------------------------------------------------- /gsheetsdb/formatting.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # 3 | # This Source Code Form is subject to the terms of the Mozilla Public 4 | # License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 | # You can obtain one at http://mozilla.org/MPL/2.0/. 6 | # 7 | # Author: Beto Dealmeida (beto@dealmeida.net) 8 | # 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import unicode_literals 13 | 14 | import re 15 | 16 | from six import string_types, text_type 17 | 18 | from moz_sql_parser.sql_parser import RESERVED 19 | 20 | 21 | VALID = re.compile(r'[a-zA-Z_]\w*') 22 | 23 | 24 | def should_quote(identifier): 25 | r""" 26 | Return true if a given identifier should be quoted. 27 | 28 | This is usually true when the identifier: 29 | 30 | - is a reserved word 31 | - contain spaces 32 | - does not match the regex `[a-zA-Z_]\w*` 33 | 34 | """ 35 | return ( 36 | identifier != '*' and ( 37 | not VALID.match(identifier) or identifier in RESERVED)) 38 | 39 | 40 | def escape(identifier, ansi_quotes, should_quote): 41 | """ 42 | Escape identifiers. 43 | 44 | ANSI uses single quotes, but many databases use back quotes. 
45 | 46 | """ 47 | if not should_quote(identifier): 48 | return identifier 49 | 50 | quote = '"' if ansi_quotes else '`' 51 | identifier = identifier.replace(quote, 2*quote) 52 | return '{0}{1}{2}'.format(quote, identifier, quote) 53 | 54 | 55 | def Operator(op, parentheses=False): 56 | op = ' {0} '.format(op) 57 | 58 | def func(self, json): 59 | out = op.join(self.dispatch(v) for v in json) 60 | if parentheses: 61 | out = '({0})'.format(out) 62 | return out 63 | 64 | return func 65 | 66 | 67 | class Formatter: 68 | 69 | clauses = [ 70 | 'select', 71 | 'from_', 72 | 'where', 73 | 'groupby', 74 | 'having', 75 | 'orderby', 76 | 'limit', 77 | 'offset', 78 | ] 79 | 80 | # simple operators 81 | _concat = Operator('||') 82 | _mult = Operator('*') 83 | _div = Operator('/', parentheses=True) 84 | _add = Operator('+') 85 | _sub = Operator('-', parentheses=True) 86 | _neq = Operator('<>') 87 | _gt = Operator('>') 88 | _lt = Operator('<') 89 | _gte = Operator('>=') 90 | _lte = Operator('<=') 91 | _eq = Operator('=') 92 | _or = Operator('OR') 93 | _and = Operator('AND') 94 | 95 | def __init__(self, ansi_quotes=True, should_quote=should_quote): 96 | self.ansi_quotes = ansi_quotes 97 | self.should_quote = should_quote 98 | 99 | def format(self, json): 100 | if 'union' in json: 101 | return self.union(json['union']) 102 | else: 103 | return self.query(json) 104 | 105 | def dispatch(self, json): 106 | if isinstance(json, list): 107 | return self.delimited_list(json) 108 | if isinstance(json, dict): 109 | if 'value' in json: 110 | return self.value(json) 111 | else: 112 | return self.op(json) 113 | if isinstance(json, string_types): 114 | return escape(json, self.ansi_quotes, self.should_quote) 115 | 116 | return text_type(json) 117 | 118 | def delimited_list(self, json): 119 | return ', '.join(self.dispatch(element) for element in json) 120 | 121 | def value(self, json): 122 | parts = [self.dispatch(json['value'])] 123 | if 'name' in json: 124 | parts.extend(['AS', 
self.dispatch(json['name'])]) 125 | return ' '.join(parts) 126 | 127 | def op(self, json): 128 | if 'on' in json: 129 | return self._on(json) 130 | 131 | if len(json) > 1: 132 | raise Exception('Operators should have only one key!') 133 | key, value = list(json.items())[0] 134 | 135 | # check if the attribute exists, and call the corresponding method; 136 | # note that we disallow keys that start with `_` to avoid giving access 137 | # to magic methods 138 | attr = '_{0}'.format(key) 139 | if hasattr(self, attr) and not key.startswith('_'): 140 | method = getattr(self, attr) 141 | return method(value) 142 | 143 | return '{0}({1})'.format(key.upper(), self.dispatch(value)) 144 | 145 | def _exists(self, value): 146 | return '{0} IS NOT NULL'.format(self.dispatch(value)) 147 | 148 | def _missing(self, value): 149 | return '{0} IS NULL'.format(self.dispatch(value)) 150 | 151 | def _like(self, pair): 152 | return '{0} LIKE {1}'.format( 153 | self.dispatch(pair[0]), self.dispatch(pair[1])) 154 | 155 | def _is(self, pair): 156 | return '{0} IS {1}'.format( 157 | self.dispatch(pair[0]), self.dispatch(pair[1])) 158 | 159 | def _in(self, json): 160 | valid = self.dispatch(json[1]) 161 | # `(10, 11, 12)` does not get parsed as literal, so it's formatted as 162 | # `10, 11, 12`. This fixes it. 
163 | if not valid.startswith('('): 164 | valid = '({0})'.format(valid) 165 | 166 | return '{0} IN {1}'.format(json[0], valid) 167 | 168 | def _case(self, checks): 169 | parts = ['CASE'] 170 | for check in checks: 171 | if isinstance(check, dict): 172 | parts.extend(['WHEN', self.dispatch(check['when'])]) 173 | parts.extend(['THEN', self.dispatch(check['then'])]) 174 | else: 175 | parts.extend(['ELSE', self.dispatch(check)]) 176 | parts.append('END') 177 | return ' '.join(parts) 178 | 179 | def _literal(self, json): 180 | if isinstance(json, list): 181 | return '({0})'.format(', '.join(self._literal(v) for v in json)) 182 | elif isinstance(json, string_types): 183 | return "'{0}'".format(json.replace("'", "''")) 184 | else: 185 | return str(json) 186 | 187 | def _on(self, json): 188 | return 'JOIN {0} ON {1}'.format( 189 | self.dispatch(json['join']), self.dispatch(json['on'])) 190 | 191 | def union(self, json): 192 | return ' UNION '.join(self.query(query) for query in json) 193 | 194 | def query(self, json): 195 | return ' '.join( 196 | part 197 | for clause in self.clauses 198 | for part in [getattr(self, clause)(json)] 199 | if part 200 | ) 201 | 202 | def select(self, json): 203 | if 'select' in json: 204 | return 'SELECT {0}'.format(self.dispatch(json['select'])) 205 | 206 | def from_(self, json): 207 | is_join = False 208 | if 'from' in json: 209 | from_ = json['from'] 210 | if not isinstance(from_, list): 211 | from_ = [from_] 212 | 213 | parts = [] 214 | for token in from_: 215 | if 'join' in token: 216 | is_join = True 217 | parts.append(self.dispatch(token)) 218 | joiner = ' ' if is_join else ', ' 219 | rest = joiner.join(parts) 220 | return 'FROM {0}'.format(rest) 221 | 222 | def where(self, json): 223 | if 'where' in json: 224 | return 'WHERE {0}'.format(self.dispatch(json['where'])) 225 | 226 | def groupby(self, json): 227 | if 'groupby' in json: 228 | return 'GROUP BY {0}'.format(self.dispatch(json['groupby'])) 229 | 230 | def having(self, json): 231 
| if 'having' in json: 232 | return 'HAVING {0}'.format(self.dispatch(json['having'])) 233 | 234 | def orderby(self, json): 235 | if 'orderby' in json: 236 | sort = json['orderby'].get('sort', '').upper() 237 | return 'ORDER BY {0} {1}'.format( 238 | self.dispatch(json['orderby']), sort).strip() 239 | 240 | def limit(self, json): 241 | if 'limit' in json: 242 | return 'LIMIT {0}'.format(self.dispatch(json['limit'])) 243 | 244 | def offset(self, json): 245 | if 'offset' in json: 246 | return 'OFFSET {0}'.format(self.dispatch(json['offset'])) 247 | 248 | 249 | def format(json, **kwargs): 250 | return Formatter(**kwargs).format(json) 251 | -------------------------------------------------------------------------------- /gsheetsdb/processors.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import warnings 7 | 8 | from six import string_types 9 | 10 | 11 | GRANULARITIES = [ 12 | 'year', 13 | 'month', 14 | 'day', 15 | 'hour', 16 | 'minute', 17 | 'second', 18 | ] 19 | 20 | LOWER_BOUNDS = [ 21 | None, # year 22 | 0, # month 23 | 1, # day 24 | 0, # hour 25 | 0, # minute 26 | 0, # second 27 | ] 28 | 29 | 30 | class Any: 31 | def __eq__(self, other): 32 | return True 33 | 34 | 35 | class OneOf: 36 | def __init__(self, valid_values): 37 | self.valid_values = valid_values 38 | 39 | def __eq__(self, other): 40 | return other in self.valid_values 41 | 42 | 43 | class JSONMatcher: 44 | 45 | def __init__(self, json=None): 46 | self.json = json 47 | 48 | def __getitem__(self, key): 49 | return self.__class__(self.json[key]) 50 | 51 | def match(self, other): 52 | raise NotImplementedError('Subclasses should implement `match`') 53 | 54 | 55 | class DummyMatcher(JSONMatcher): 56 | 57 | def match(self, other): 58 | return False 59 | 60 | 61 | class SubsetMatcher(JSONMatcher): 62 | 
63 | def match(self, other): 64 | return is_subset(self.json, other) 65 | 66 | 67 | def is_subset(json, other): 68 | if isinstance(json, list): 69 | if not isinstance(other, list): 70 | return False 71 | 72 | for value in json: 73 | if not is_subset(value, other): 74 | return False 75 | return True 76 | 77 | elif isinstance(json, dict): 78 | if isinstance(other, string_types): 79 | return False 80 | elif isinstance(other, dict): 81 | other = [other] 82 | 83 | for k, v in json.items(): 84 | # each value should be a subset of the value in other 85 | for option in other: 86 | if k in option and is_subset(v, option[k]): 87 | break 88 | else: 89 | return False 90 | 91 | return True 92 | 93 | elif isinstance(other, list): 94 | return json in other 95 | 96 | else: 97 | return json == other 98 | 99 | 100 | class Processor: 101 | 102 | pattern = DummyMatcher() 103 | 104 | @classmethod 105 | def match(cls, parsed_query): 106 | return cls.pattern.match(parsed_query) 107 | 108 | def pre_process(self, parsed_query, column_map): 109 | return parsed_query 110 | 111 | def post_process(self, payload, aliases): 112 | return payload 113 | 114 | 115 | class DateTrunc(Processor): 116 | 117 | """ 118 | Implement `datetrunc` UDF. 119 | 120 | sql> SELECT time, datetrunc('month', time) FROM "http://example.com" 121 | time datetrunc-time-month 122 | ------------------- ---------------------- 123 | 2018-09-01 00:00:00 2018-09-01 00:00:00 124 | 2018-09-02 00:00:00 2018-09-01 00:00:00 125 | 2018-09-03 00:00:00 2018-09-01 00:00:00 126 | 2018-09-04 00:00:00 2018-09-01 00:00:00 127 | 2018-09-05 00:00:00 2018-09-01 00:00:00 128 | 2018-09-06 00:00:00 2018-09-01 00:00:00 129 | 2018-09-07 00:00:00 2018-09-01 00:00:00 130 | 2018-09-08 00:00:00 2018-09-01 00:00:00 131 | 2018-09-09 00:00:00 2018-09-01 00:00:00 132 | sql> 133 | 134 | This works by calling multiple time functions that extract year/month/etc, 135 | and padding the values below the requested granularity. 
The query above 136 | would be translated to: 137 | 138 | SELECT time, year(time), month(time) 139 | 140 | The post-processor then build the new datetime by using the year and month, 141 | and padding day/time to a lower boundary. 142 | 143 | """ 144 | 145 | pattern = SubsetMatcher( 146 | { 147 | 'select': { 148 | 'value': { 149 | 'datetrunc': [{'literal': OneOf(GRANULARITIES)}, Any()], 150 | }, 151 | }, 152 | }, 153 | ) 154 | 155 | def pre_process(self, parsed_query, column_map): 156 | select = parsed_query['select'] 157 | if not isinstance(select, list): 158 | select = [select] 159 | 160 | # match select 161 | new_select = [] 162 | matcher = self.pattern['select'] 163 | self.new_columns = [] 164 | for i, expr in enumerate(select): 165 | if matcher.match(expr): 166 | alias = expr.get('name') 167 | self.new_columns.append((i, alias, expr)) 168 | new_select.extend(self.get_columns(expr)) 169 | else: 170 | new_select.append(expr) 171 | 172 | # remove duplicates 173 | seen = set() 174 | deduped_select = [] 175 | for expr in new_select: 176 | value = expr['value'] 177 | if isinstance(value, dict): 178 | value = tuple(value.items()) 179 | if value not in seen: 180 | seen.add(value) 181 | deduped_select.append(expr) 182 | 183 | # remove columns from group by 184 | groupby = parsed_query.get('groupby') 185 | if groupby: 186 | new_groupby = [] 187 | matcher = SubsetMatcher({'value': { 188 | 'datetrunc': [{'literal': OneOf(GRANULARITIES)}, Any()]}, 189 | }) 190 | if not isinstance(groupby, list): 191 | groupby = [groupby] 192 | for expr in groupby: 193 | if matcher.match(expr): 194 | new_groupby.extend( 195 | self.get_columns(expr, alias_column=False)) 196 | else: 197 | new_groupby.append(expr) 198 | 199 | # remove duplicates 200 | seen = set() 201 | deduped_groupby = [] 202 | for expr in new_groupby: 203 | value = expr['value'] 204 | if isinstance(value, dict): 205 | value = tuple(value.items()) 206 | if value not in seen: 207 | seen.add(value) 208 | 
deduped_groupby.append(expr) 209 | 210 | parsed_query['groupby'] = deduped_groupby 211 | 212 | parsed_query['select'] = deduped_select 213 | return parsed_query 214 | 215 | def post_process(self, payload, aliases): 216 | added_columns = [ 217 | alias and 218 | alias.startswith('__{0}__'.format(self.__class__.__name__)) 219 | for alias in aliases 220 | ] 221 | 222 | cols = payload['table']['cols'] 223 | payload['table']['cols'] = [ 224 | col for (col, added) in zip(cols, added_columns) if not added 225 | ] 226 | 227 | for position, alias, expr in self.new_columns: 228 | id_ = 'datetrunc-{name}-{granularity}'.format( 229 | name=expr['value']['datetrunc'][1], 230 | granularity=expr['value']['datetrunc'][0]['literal'], 231 | ) 232 | payload['table']['cols'].insert( 233 | position, 234 | {'id': id_, 'label': alias or id_, 'type': 'datetime'}) 235 | 236 | for row in payload['table']['rows']: 237 | row_c = row['c'] 238 | row['c'] = [ 239 | value for (value, added) in zip(row['c'], added_columns) 240 | if not added 241 | ] 242 | for position, alias, expr in self.new_columns: 243 | row['c'].insert( 244 | position, {'v': self.get_value(cols, row_c, expr)}) 245 | 246 | return payload 247 | 248 | def get_value(self, cols, row_c, expr): 249 | """ 250 | Build the datetime from individual columns. 
251 | 252 | """ 253 | name = expr['value']['datetrunc'][1] 254 | granularity = expr['value']['datetrunc'][0]['literal'] 255 | i = GRANULARITIES.index(granularity) 256 | 257 | # map function to index 258 | labels = [col['label'] for col in cols] 259 | values = [] 260 | for func_name in GRANULARITIES: 261 | label = '{0}({1})'.format(func_name, name) 262 | if label in labels: 263 | values.append(row_c[labels.index(label)]['v']) 264 | 265 | # truncate values to requested granularity and pad with lower bounds 266 | args = values[:i+1] 267 | args += LOWER_BOUNDS[len(args):] 268 | args = [str(int(arg)) for arg in args] 269 | 270 | return 'Date({0})'.format(','.join(args)) 271 | 272 | def get_columns(self, expr, alias_column=True): 273 | """ 274 | Get all columns required to compute a given granularity. 275 | 276 | """ 277 | name = expr['value']['datetrunc'][1] 278 | granularity = expr['value']['datetrunc'][0]['literal'] 279 | 280 | for func_name in GRANULARITIES: 281 | alias = '__{namespace}__{func_name}__{name}'.format( 282 | namespace=self.__class__.__name__, 283 | func_name=func_name, 284 | name=name, 285 | ) 286 | column = {'value': {func_name: name}} 287 | if alias_column: 288 | column['name'] = alias 289 | yield column 290 | if func_name == granularity: 291 | break 292 | 293 | 294 | class CountStar(Processor): 295 | 296 | pattern = SubsetMatcher({'select': {'value': {'count': '*'}}}) 297 | 298 | def pre_process(self, parsed_query, column_map): 299 | warnings.warn( 300 | 'COUNT(*) only works if at least one column has no NULLs') 301 | 302 | select = parsed_query['select'] 303 | if not isinstance(select, list): 304 | select = [select] 305 | 306 | new_select = [] 307 | matcher = self.pattern['select'] 308 | self.new_columns = [] 309 | for i, expr in enumerate(select): 310 | if matcher.match(expr): 311 | alias = expr.get('name', 'count star') 312 | self.new_columns.append((i, alias, expr)) 313 | else: 314 | new_select.append(expr) 315 | 316 | # count each column 317 | 
for label in column_map: 318 | alias = '__{namespace}__{label}'.format( 319 | namespace=self.__class__.__name__, label=label) 320 | new_select.append({'value': {'count': label}, 'name': alias}) 321 | 322 | parsed_query['select'] = new_select 323 | return parsed_query 324 | 325 | def post_process(self, payload, aliases): 326 | added_columns = [ 327 | alias and 328 | alias.startswith('__{0}__'.format(self.__class__.__name__)) 329 | for alias in aliases 330 | ] 331 | 332 | payload['table']['cols'] = [ 333 | col for (col, added) 334 | in zip(payload['table']['cols'], added_columns) 335 | if not added 336 | ] 337 | 338 | position, alias, expr = self.new_columns[0] 339 | payload['table']['cols'].insert( 340 | position, {'id': 'count-star', 'label': alias, 'type': 'number'}) 341 | 342 | for row in payload['table']['rows']: 343 | values = [ 344 | value['v'] for (value, added) in zip(row['c'], added_columns) 345 | if added 346 | ] 347 | count_star = max(values) 348 | row['c'] = [ 349 | value for (value, added) in zip(row['c'], added_columns) 350 | if not added 351 | ] 352 | row['c'].insert(position, {'v': count_star}) 353 | 354 | # the API returns no rows when the count is zero 355 | if not payload['table']['rows']: 356 | payload['table']['rows'].append({'c': [{'v': 0}]}) 357 | 358 | return payload 359 | 360 | 361 | processors = [CountStar, DateTrunc] 362 | -------------------------------------------------------------------------------- /gsheetsdb/query.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from collections import OrderedDict 7 | import json 8 | import logging 9 | 10 | from google.auth.transport.requests import AuthorizedSession 11 | from moz_sql_parser import parse as parse_sql 12 | import pyparsing 13 | from requests import Session 14 | from 
six.moves.urllib import parse 15 | 16 | from gsheetsdb.convert import convert_rows 17 | from gsheetsdb.exceptions import InterfaceError, ProgrammingError 18 | from gsheetsdb.processors import processors 19 | from gsheetsdb.translator import extract_column_aliases, translate 20 | from gsheetsdb.types import Type 21 | from gsheetsdb.url import extract_url, get_url 22 | from gsheetsdb.utils import format_gsheet_error, format_moz_error 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | # the JSON payload has this in the beginning 28 | LEADING = ")]}'\n" 29 | 30 | 31 | def get_column_map(url, credentials=None): 32 | query = 'SELECT * LIMIT 0' 33 | result = run_query(url, query, credentials) 34 | return OrderedDict( 35 | sorted((col['label'], col['id']) for col in result['table']['cols'])) 36 | 37 | 38 | def run_query(baseurl, query, credentials=None): 39 | url = '{baseurl}&tq={query}'.format( 40 | baseurl=baseurl, query=parse.quote(query, safe='/()')) 41 | headers = {'X-DataSource-Auth': 'true'} 42 | 43 | if credentials: 44 | session = AuthorizedSession(credentials) 45 | else: 46 | session = Session() 47 | 48 | r = session.get(url, headers=headers) 49 | if r.encoding is None: 50 | r.encoding = 'utf-8' 51 | 52 | # raise any error messages 53 | if r.status_code != 200: 54 | raise ProgrammingError(r.text) 55 | 56 | if r.text.startswith(LEADING): 57 | result = json.loads(r.text[len(LEADING):]) 58 | else: 59 | result = r.json() 60 | 61 | return result 62 | 63 | 64 | def get_description_from_payload(payload): 65 | """ 66 | Return description from a single row. 67 | 68 | We only return the name, type (inferred from the data) and if the values 69 | can be NULL. String columns in Druid are NULLable. Numeric columns are NOT 70 | NULL. 
71 | """ 72 | return [ 73 | ( 74 | col['label'], # name 75 | Type[col['type'].upper()], # type_code 76 | None, # [display_size] 77 | None, # [internal_size] 78 | None, # [precision] 79 | None, # [scale] 80 | True, # [null_ok] 81 | ) 82 | for col in payload['table']['cols'] 83 | ] 84 | 85 | 86 | def execute(query, headers=0, credentials=None): 87 | try: 88 | parsed_query = parse_sql(query) 89 | except pyparsing.ParseException as e: 90 | raise ProgrammingError(format_moz_error(query, e)) 91 | 92 | # fetch aliases, since they will be removed by the translator 93 | original_aliases = extract_column_aliases(parsed_query) 94 | 95 | # extract URL from the `FROM` clause 96 | from_ = extract_url(query) 97 | baseurl = get_url(from_, headers) 98 | 99 | # verify that URL is actually a Google spreadsheet 100 | parsed = parse.urlparse(baseurl) 101 | if not parsed.netloc == 'docs.google.com': 102 | raise InterfaceError('Invalid URL, must be a docs.google.com URL!') 103 | 104 | # map between labels and ids, eg, `{ 'country': 'A' }` 105 | column_map = get_column_map(baseurl, credentials) 106 | 107 | # preprocess 108 | used_processors = [] 109 | for cls in processors: 110 | if cls.match(parsed_query): 111 | processor = cls() 112 | parsed_query = processor.pre_process(parsed_query, column_map) 113 | used_processors.append(processor) 114 | processed_aliases = extract_column_aliases(parsed_query) 115 | 116 | # translate colum names to ids and remove aliases 117 | translated_query = translate(parsed_query, column_map) 118 | logger.info('Original query: {}'.format(query)) 119 | logger.info('Translated query: {}'.format(translated_query)) 120 | 121 | # run query 122 | payload = run_query(baseurl, translated_query, credentials) 123 | if payload['status'] == 'error': 124 | raise ProgrammingError( 125 | format_gsheet_error(query, translated_query, payload['errors'])) 126 | 127 | # postprocess 128 | for processor in used_processors: 129 | payload = processor.post_process(payload, 
processed_aliases) 130 | 131 | # add aliases back 132 | cols = payload['table']['cols'] 133 | for alias, col in zip(original_aliases, cols): 134 | if alias is not None: 135 | col['label'] = alias 136 | 137 | # remove columns with no label/name 138 | cols = [col for col in cols if col['label']] 139 | payload['table']['cols'] = cols 140 | 141 | description = get_description_from_payload(payload) 142 | 143 | # convert rows to proper type (datetime, eg) 144 | rows = payload['table']['rows'] 145 | results = convert_rows(cols, rows) 146 | 147 | return results, description 148 | -------------------------------------------------------------------------------- /gsheetsdb/sqlite.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import datetime 6 | import logging 7 | import sqlite3 8 | 9 | from gsheetsdb.convert import convert_rows 10 | from gsheetsdb.exceptions import ProgrammingError 11 | from gsheetsdb.query import run_query 12 | from gsheetsdb.url import extract_url, get_url 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | # Google Spreadsheet types to SQLite types 18 | typemap = { 19 | 'string': 'text', 20 | 'number': 'real', 21 | 'boolean': 'boolean', 22 | 'date': 'date', 23 | 'datetime': 'timestamp', 24 | 'timeofday': 'timeofday', 25 | } 26 | 27 | 28 | def adapt_timeofday(timeofday): 29 | return ( 30 | 3600e6 * timeofday.hour + 31 | 60e6 * timeofday.minute + 32 | 1e6 * timeofday.second + 33 | timeofday.microsecond 34 | ) 35 | 36 | 37 | def convert_timeofday(val): 38 | val = int(val) 39 | hour, val = divmod(val, int(3600e6)) 40 | minute, val = divmod(val, int(60e6)) 41 | second, val = divmod(val, int(1e6)) 42 | microsecond = int(val) 43 | return datetime.time(hour, minute, second, microsecond) 44 | 45 | 46 | sqlite3.register_adapter(datetime.time, adapt_timeofday) 47 | 
sqlite3.register_converter('timeofday', convert_timeofday) 48 | 49 | 50 | def create_table(cursor, table, payload): 51 | cols = ', '.join( 52 | '"{name}" {type}'.format( 53 | name=col['label'] or col['id'], type=typemap[col['type']]) 54 | for col in payload['table']['cols'] 55 | ) 56 | query = 'CREATE TABLE "{table}" ({cols})'.format(table=table, cols=cols) 57 | logger.info(query) 58 | cursor.execute(query) 59 | 60 | 61 | def insert_into(cursor, table, payload): 62 | cols = payload['table']['cols'] 63 | values = ', '.join('?' for col in cols) 64 | query = 'INSERT INTO "{table}" VALUES ({values})'.format( 65 | table=table, values=values) 66 | rows = convert_rows(cols, payload['table']['rows']) 67 | logger.info(query) 68 | cursor.executemany(query, rows) 69 | 70 | 71 | def execute(query, headers=0, credentials=None): 72 | # fetch all the data 73 | from_ = extract_url(query) 74 | if not from_: 75 | raise ProgrammingError('Invalid query: {query}'.format(query=query)) 76 | baseurl = get_url(from_, headers) 77 | payload = run_query(baseurl, 'SELECT *', credentials) 78 | 79 | # create table 80 | conn = sqlite3.connect(':memory:', detect_types=sqlite3.PARSE_DECLTYPES) 81 | cursor = conn.cursor() 82 | create_table(cursor, from_, payload) 83 | insert_into(cursor, from_, payload) 84 | conn.commit() 85 | 86 | # run query in SQLite instead 87 | logger.info('SQLite query: {}'.format(query)) 88 | results = cursor.execute(query).fetchall() 89 | description = cursor.description 90 | 91 | return results, description 92 | -------------------------------------------------------------------------------- /gsheetsdb/translator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | try: 7 | from moz_sql_parser import format 8 | except ImportError: # pragma: no cover 9 | from 
gsheetsdb.formatting import format 10 | from six import string_types 11 | 12 | from gsheetsdb.exceptions import NotSupportedError 13 | 14 | 15 | def replace(obj, replacements): 16 | """ 17 | Modify parsed query recursively in place. 18 | 19 | """ 20 | if isinstance(obj, list): 21 | for i, value in enumerate(obj): 22 | if isinstance(value, string_types) and value in replacements: 23 | obj[i] = replacements[value] 24 | elif isinstance(value, (list, dict)): 25 | replace(value, replacements) 26 | elif isinstance(obj, dict): 27 | for key, value in obj.items(): 28 | if isinstance(value, string_types) and value in replacements: 29 | obj[key] = replacements[value] 30 | elif isinstance(value, list): 31 | replace(value, replacements) 32 | elif isinstance(value, dict) and 'literal' not in value: 33 | replace(value, replacements) 34 | 35 | 36 | def remove_aliases(parsed_query): 37 | select = parsed_query['select'] 38 | if isinstance(select, dict): 39 | select = [select] 40 | 41 | for clause in select: 42 | if 'name' in clause: 43 | del clause['name'] 44 | 45 | 46 | def unalias_orderby(parsed_query): 47 | if 'orderby' not in parsed_query: 48 | return 49 | 50 | select = parsed_query['select'] 51 | if isinstance(select, dict): 52 | select = [select] 53 | 54 | alias_to_value = { 55 | clause['name']: clause['value'] 56 | for clause in select 57 | if isinstance(clause, dict) and 'name' in clause 58 | } 59 | 60 | for k, v in parsed_query['orderby'].items(): 61 | if isinstance(v, string_types) and v in alias_to_value: 62 | parsed_query['orderby'][k] = alias_to_value[v] 63 | 64 | 65 | def extract_column_aliases(parsed_query): 66 | select = parsed_query['select'] 67 | if isinstance(select, dict): 68 | select = [select] 69 | 70 | aliases = [] 71 | for clause in select: 72 | if isinstance(clause, dict): 73 | aliases.append(clause.get('name')) 74 | else: 75 | aliases.append(None) 76 | 77 | return aliases 78 | 79 | 80 | def translate(parsed_query, column_map=None): 81 | if column_map is 
None: 82 | column_map = {} 83 | 84 | # HAVING is not supported 85 | if 'having' in parsed_query: 86 | raise NotSupportedError('HAVING not supported') 87 | 88 | from_ = parsed_query.pop('from') 89 | if not isinstance(from_, string_types): 90 | raise NotSupportedError('FROM should be a URL') 91 | 92 | unalias_orderby(parsed_query) 93 | remove_aliases(parsed_query) 94 | replace(parsed_query, column_map) 95 | 96 | return format(parsed_query) 97 | -------------------------------------------------------------------------------- /gsheetsdb/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from enum import Enum 7 | 8 | 9 | class Type(Enum): 10 | STRING = 'string' 11 | NUMBER = 'number' 12 | BOOLEAN = 'boolean' 13 | DATE = 'date' 14 | DATETIME = 'datetime' 15 | TIMEOFDAY = 'timeofday' 16 | -------------------------------------------------------------------------------- /gsheetsdb/url.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from collections import OrderedDict 7 | 8 | from moz_sql_parser import parse as parse_sql 9 | import pyparsing 10 | import re 11 | from six.moves.urllib import parse 12 | from gsheetsdb.exceptions import ProgrammingError 13 | 14 | 15 | FROM_REGEX = re.compile(' from ("http.*?")', re.IGNORECASE) 16 | 17 | 18 | def get_url(url, headers=0, gid=0, sheet=None): 19 | parts = parse.urlparse(url) 20 | if parts.path.endswith('/edit'): 21 | path = parts.path[:-len('/edit')] 22 | else: 23 | path = parts.path 24 | path = '/'.join((path.rstrip('/'), 'gviz/tq')) 25 | 26 | qs = parse.parse_qs(parts.query) 27 | if 'headers' in qs: 28 | headers = 
int(qs['headers'][-1]) 29 | if 'gid' in qs: 30 | gid = qs['gid'][-1] 31 | if 'sheet' in qs: 32 | sheet = qs['sheet'][-1] 33 | 34 | if parts.fragment.startswith('gid='): 35 | gid = parts.fragment[len('gid='):] 36 | 37 | args = OrderedDict() 38 | if headers > 0: 39 | args['headers'] = headers 40 | if sheet is not None: 41 | args['sheet'] = sheet 42 | else: 43 | args['gid'] = gid 44 | params = parse.urlencode(args) 45 | 46 | netloc = parts.netloc.replace("\.", ".") 47 | 48 | return parse.urlunparse( 49 | (parts.scheme, netloc, path, None, params, None)) 50 | 51 | 52 | def extract_url(sql): 53 | try: 54 | url = parse_sql(sql)['from'] 55 | except pyparsing.ParseException: 56 | # fallback to regex to extract from 57 | match = FROM_REGEX.search(sql) 58 | if match: 59 | return match.group(1).strip('"') 60 | return 61 | 62 | while isinstance(url, dict): 63 | url = url['value']['from'] 64 | 65 | return url 66 | -------------------------------------------------------------------------------- /gsheetsdb/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import re 7 | 8 | 9 | POSITION = re.compile(r'at line (?P\d+), column (?P\d+)') 10 | 11 | 12 | def format_moz_error(query, exception): 13 | """ 14 | Format syntax error when parsing the original query. 15 | 16 | """ 17 | line = exception.lineno 18 | column = exception.col 19 | detailed_message = str(exception) 20 | 21 | msg = query.split('\n')[:line] 22 | msg.append('{indent}^'.format(indent=' ' * (column - 1))) 23 | msg.append(detailed_message) 24 | return '\n'.join(msg) 25 | 26 | 27 | def format_gsheet_error(query, translated_query, errors): 28 | """ 29 | Format syntax error returned by API when parsing translated query. 
30 | 31 | """ 32 | error_messages = [] 33 | for error in errors: 34 | detailed_message = error['detailed_message'] 35 | match = POSITION.search(detailed_message) 36 | if match: 37 | groups = match.groupdict() 38 | line = int(groups['line']) 39 | column = int(groups['column']) 40 | 41 | msg = translated_query.split('\n')[:line] 42 | msg.append('{indent}^'.format(indent=' ' * (column - 1))) 43 | msg.append(detailed_message) 44 | error_messages.append('\n'.join(msg)) 45 | else: 46 | error_messages.append(detailed_message) 47 | 48 | return """ 49 | Original query: 50 | {query} 51 | 52 | Translated query: 53 | {translated_query} 54 | 55 | Error{plural}: 56 | {error_messages} 57 | """.format( 58 | query=query, 59 | translated_query=translated_query, 60 | plural='s' if len(errors) > 1 else '', 61 | error_messages='\n'.join(error_messages), 62 | ).strip() 63 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==20.2.0 2 | bleach==3.2.1 3 | cachetools==4.1.1 4 | certifi==2020.6.20 5 | chardet==3.0.4 6 | colorama==0.4.3 7 | coverage==5.3 8 | docopt==0.6.2 9 | docutils==0.16 10 | flake8==3.8.3 11 | google-auth==1.21.2 12 | idna==2.10 13 | iniconfig==1.0.1 14 | keyring==21.4.0 15 | mccabe==0.6.1 16 | mo-future==3.89.20246 17 | more-itertools==8.5.0 18 | moz-sql-parser==3.32.20026 19 | nose==1.3.7 20 | packaging==20.4 21 | pipreqs==0.4.10 22 | pkginfo==1.5.0.1 23 | pluggy==0.13.1 24 | prompt-toolkit==3.0.7 25 | py==1.9.0 26 | pyasn1==0.4.8 27 | pyasn1-modules==0.2.8 28 | pycodestyle==2.6.0 29 | pyflakes==2.2.0 30 | Pygments==2.7.1 31 | pyparsing==2.3.1 32 | pytest==6.0.2 33 | pytest-cov==2.10.1 34 | readme-renderer==26.0 35 | requests==2.24.0 36 | requests-mock==1.8.0 37 | requests-toolbelt==0.9.1 38 | rfc3986==1.4.0 39 | rsa==4.6 40 | six==1.15.0 41 | SQLAlchemy==1.3.19 42 | tabulate==0.8.7 43 | toml==0.10.1 44 | tqdm==4.49.0 45 | 
twine==3.2.0 46 | urllib3==1.25.10 47 | wcwidth==0.2.5 48 | webencodings==0.5.1 49 | yarg==0.1.9 50 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [nosetests] 2 | match=^test 3 | nocapture=1 4 | with-coverage=1 5 | cover-package=gsheetsdb 6 | cover-erase=1 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pip install twine 6 | 7 | import os 8 | import sys 9 | from shutil import rmtree 10 | 11 | from setuptools import find_packages, setup, Command 12 | 13 | # Package meta-data. 14 | NAME = 'gsheetsdb' 15 | DESCRIPTION = 'Python DB-API and SQLAlchemy interface for Google Spreadsheets.' 16 | URL = 'https://github.com/betodealmeida/gsheets-db-api' 17 | EMAIL = 'beto@dealmeida.net' 18 | AUTHOR = 'Beto Dealmeida' 19 | 20 | # What packages are required for this module to be executed? 21 | REQUIRED = [ 22 | 'google-auth', 23 | 'moz_sql_parser', 24 | 'requests>=2.20.0', 25 | 'six', 26 | ] 27 | if sys.version_info < (3, 4): 28 | REQUIRED.append('enum34') 29 | 30 | sqlalchemy_extras = [ 31 | 'sqlalchemy', 32 | ] 33 | 34 | cli_extras = [ 35 | 'docopt', 36 | 'pygments', 37 | 'prompt_toolkit>=2', 38 | 'tabulate', 39 | ] 40 | 41 | development_extras = [ 42 | 'coverage', 43 | 'flake8', 44 | 'nose', 45 | 'pipreqs', 46 | 'pytest>=2.8', 47 | 'pytest-cov', 48 | 'requests_mock', 49 | 'twine', 50 | ] 51 | if sys.version_info < (3, 3): 52 | development_extras.append('mock') 53 | 54 | 55 | # The rest you shouldn't have to touch too much :) 56 | # ------------------------------------------------ 57 | # Except, perhaps the License and Trove Classifiers! 
class UploadCommand(Command):
    """Support setup.py upload."""

    description = 'Build and publish the package.'
    user_options = []

    @staticmethod
    def status(s):
        """Prints things in bold."""
        # ANSI escapes: \033[1m enables bold, \033[0m resets formatting.
        print('\033[1m{0}\033[0m'.format(s))

    def initialize_options(self):
        # This command takes no options; nothing to initialize.
        pass

    def finalize_options(self):
        # No options, so there is nothing to validate either.
        pass

    def run(self):
        # Start from a clean slate: drop artifacts from earlier builds.
        try:
            self.status('Removing previous builds…')
            rmtree(os.path.join(here, 'dist'))
        except OSError:
            # dist/ did not exist yet -- nothing to clean up.
            pass

        self.status('Building Source and Wheel (universal) distribution…')
        build_command = '{0} setup.py sdist bdist_wheel --universal'.format(
            sys.executable)
        os.system(build_command)

        self.status('Uploading the package to PyPi via Twine…')
        os.system('twine upload dist/*')

        sys.exit()
include_package_data=True, 134 | license='MIT', 135 | classifiers=[ 136 | # Trove classifiers 137 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 138 | 'License :: OSI Approved :: MIT License', 139 | 'Programming Language :: Python', 140 | 'Programming Language :: Python :: 3.4', 141 | 'Programming Language :: Python :: 3.5', 142 | 'Programming Language :: Python :: 3.6', 143 | 'Programming Language :: Python :: 3.7', 144 | 'Programming Language :: Python :: Implementation :: CPython', 145 | 'Programming Language :: Python :: Implementation :: PyPy' 146 | ], 147 | # $ setup.py publish support. 148 | cmdclass={ 149 | 'upload': UploadCommand, 150 | }, 151 | ) 152 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/betodealmeida/gsheets-db-api/f023b32986d4da9a501fca8d435f2b6edc153353/tests/__init__.py -------------------------------------------------------------------------------- /tests/context.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os 3 | import sys 4 | sys.path.insert(0, os.path.abspath( 5 | os.path.join(os.path.dirname(__file__), '..'))) 6 | 7 | import gsheetsdb 8 | from gsheetsdb import console 9 | from gsheetsdb import exceptions 10 | from gsheetsdb.convert import convert_rows 11 | from gsheetsdb.db import ( 12 | apply_parameters, 13 | Connection, 14 | check_closed, 15 | connect, 16 | check_result, 17 | ) 18 | from gsheetsdb.dialect import add_headers, GSheetsDialect 19 | from gsheetsdb.processors import ( 20 | Any, 21 | CountStar, 22 | DateTrunc, 23 | is_subset, 24 | Processor, 25 | SubsetMatcher, 26 | ) 27 | from gsheetsdb.query import ( 28 | execute, 29 | get_column_map, 30 | get_description_from_payload, 31 | LEADING, 32 | run_query, 33 | ) 34 | from gsheetsdb.translator import 
extract_column_aliases, translate 35 | from gsheetsdb.types import Type 36 | from gsheetsdb.utils import format_gsheet_error, format_moz_error 37 | from gsheetsdb.url import extract_url, get_url 38 | -------------------------------------------------------------------------------- /tests/test_console.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | try: 4 | from unittest.mock import patch 5 | except ImportError: 6 | from mock import patch 7 | 8 | import unittest 9 | 10 | import requests_mock 11 | from six import StringIO 12 | 13 | from .context import console, exceptions 14 | 15 | 16 | class ConsoleTestSuite(unittest.TestCase): 17 | 18 | @patch('gsheetsdb.console.docopt') 19 | @patch('sys.stdout', new_callable=StringIO) 20 | @patch('gsheetsdb.console.prompt') 21 | def test_main(self, prompt, stdout, docopt): 22 | docopt.return_value = { 23 | '--headers': '0', 24 | '--raise': False, 25 | '--service-account-file': None, 26 | '--subject': None, 27 | } 28 | prompt.side_effect = EOFError() 29 | console.main() 30 | self.assertEqual(stdout.getvalue(), 'See ya!\n') 31 | 32 | @patch('gsheetsdb.console.docopt') 33 | @requests_mock.Mocker() 34 | @patch('sys.stdout', new_callable=StringIO) 35 | @patch('gsheetsdb.console.prompt') 36 | def test_main_query(self, m, prompt, stdout, docopt): 37 | docopt.return_value = { 38 | '--headers': '0', 39 | '--raise': False, 40 | '--service-account-file': None, 41 | '--subject': None, 42 | } 43 | header_payload = { 44 | 'table': { 45 | 'cols': [ 46 | {'id': 'A', 'label': 'country', 'type': 'string'}, 47 | { 48 | 'id': 'B', 49 | 'label': 'cnt', 50 | 'type': 'number', 51 | 'pattern': 'General', 52 | }, 53 | ], 54 | }, 55 | } 56 | query_payload = { 57 | 'status': 'ok', 58 | 'table': { 59 | 'cols': [ 60 | {'id': 'A', 'label': 'country', 'type': 'string'}, 61 | { 62 | 'id': 'B', 63 | 'label': 'cnt', 64 | 'type': 'number', 65 | 'pattern': 'General', 66 | }, 67 | ], 68 | 'rows': 
[ 69 | {'c': [{'v': 'BR'}, {'v': 1.0, 'f': '1'}]}, 70 | {'c': [{'v': 'IN'}, {'v': 2.0, 'f': '2'}]}, 71 | ], 72 | }, 73 | } 74 | m.get( 75 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 76 | json=header_payload, 77 | ) 78 | m.get( 79 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 80 | json=query_payload, 81 | ) 82 | 83 | def gen(): 84 | yield 'SELECT * FROM "http://docs.google.com/"' 85 | raise EOFError() 86 | 87 | prompt.side_effect = gen() 88 | console.main() 89 | result = stdout.getvalue() 90 | expected = ( 91 | 'country cnt\n' 92 | '--------- -----\n' 93 | 'BR 1\n' 94 | 'IN 2\n' 95 | 'See ya!\n' 96 | ) 97 | self.assertEqual(result, expected) 98 | 99 | @patch('gsheetsdb.console.docopt') 100 | @patch('sys.stdout', new_callable=StringIO) 101 | @patch('gsheetsdb.console.prompt') 102 | def test_console_exception(self, prompt, stdout, docopt): 103 | docopt.return_value = { 104 | '--headers': '0', 105 | '--raise': False, 106 | '--service-account-file': None, 107 | '--subject': None, 108 | } 109 | 110 | def gen(): 111 | yield 'SELECTSELECTSELECT' 112 | raise EOFError() 113 | 114 | prompt.side_effect = gen() 115 | console.main() 116 | result = stdout.getvalue() 117 | expected = ( 118 | 'Invalid query: SELECTSELECTSELECT\n' 119 | 'See ya!\n' 120 | ) 121 | self.assertEqual(result, expected) 122 | 123 | @patch('gsheetsdb.console.docopt') 124 | @patch('sys.stdout', new_callable=StringIO) 125 | @patch('gsheetsdb.console.prompt') 126 | def test_console_raise_exception(self, prompt, stdout, docopt): 127 | docopt.return_value = { 128 | '--headers': '0', 129 | '--raise': True, 130 | '--service-account-file': None, 131 | '--subject': None, 132 | } 133 | 134 | def gen(): 135 | yield 'SELECTSELECTSELECT' 136 | raise EOFError() 137 | 138 | prompt.side_effect = gen() 139 | with self.assertRaises(exceptions.ProgrammingError): 140 | console.main() 141 | -------------------------------------------------------------------------------- 
class ConvertTestSuite(unittest.TestCase):
    """
    Tests for ``convert_rows``, which turns a Google Visualization API
    payload into rows of native Python values.

    """

    # Canned GViz API response covering every column type the connector
    # understands: datetime, number, boolean, date, timeofday and
    # string. The second row uses None cells to exercise NULL handling.
    payload = {
        "version": "0.6",
        "reqId": "0",
        "status": "ok",
        "sig": "1788543417",
        "table": {
            "cols": [
                {
                    "id": "A",
                    "label": "datetime",
                    "type": "datetime",
                    "pattern": "M/d/yyyy H:mm:ss",
                },
                {
                    "id": "B",
                    "label": "number",
                    "type": "number",
                    "pattern": "General",
                },
                {
                    "id": "C",
                    "label": "boolean",
                    "type": "boolean",
                },
                {
                    "id": "D",
                    "label": "date",
                    "type": "date",
                    "pattern": "M/d/yyyy",
                },
                {
                    "id": "E",
                    "label": "timeofday",
                    "type": "timeofday",
                    "pattern": "h:mm:ss am/pm",
                },
                {
                    "id": "F",
                    "label": "string",
                    "type": "string",
                },
            ],
            "rows": [
                {
                    "c": [
                        # GViz serializes dates as "Date(y,m,d,...)"
                        # with a zero-based month: month 8 renders as
                        # "9/1/2018".
                        {"v": "Date(2018,8,1,0,0,0)", "f": "9/1/2018 0:00:00"},
                        {"v": 1.0, "f": "1"},
                        {"v": True, "f": "TRUE"},
                        {"v": "Date(2018,0,1)", "f": "1/1/2018"},
                        {"v": [17, 0, 0, 0], "f": "5:00:00 PM"},
                        {"v": "test"},
                    ],
                },
                {
                    "c": [
                        None,
                        {"v": 1.0, "f": "1"},
                        {"v": True, "f": "TRUE"},
                        None,
                        None,
                        {"v": "test"},
                    ],
                },
            ],
        },
    }

    def test_convert(self):
        """Rows convert to namedtuples; None cells become None values."""
        cols = self.payload['table']['cols']
        rows = self.payload['table']['rows']
        result = convert_rows(cols, rows)
        # Expected shape: one namedtuple per row, one field per column
        # label, with GViz values converted to native Python types.
        Row = namedtuple(
            'Row', 'datetime number boolean date timeofday string')
        expected = [
            Row(
                datetime=datetime.datetime(2018, 9, 1, 0, 0),
                number=1.0,
                boolean=True,
                date=datetime.date(2018, 1, 1),
                timeofday=datetime.time(17, 0),
                string='test',
            ),
            Row(
                datetime=None,
                number=1.0,
                boolean=True,
                date=None,
                timeofday=None,
                string='test',
            ),
        ]
        self.assertEqual(result, expected)
'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 82 | json=self.header_payload, 83 | ) 84 | m.get( 85 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 86 | json=self.query_payload, 87 | ) 88 | 89 | with Connection() as conn: 90 | result = conn.execute( 91 | 'SELECT * FROM "http://docs.google.com/"').fetchall() 92 | Row = namedtuple('Row', 'country cnt') 93 | expected = [Row(country=u'BR', cnt=1.0), Row(country=u'IN', cnt=2.0)] 94 | self.assertEqual(result, expected) 95 | 96 | @requests_mock.Mocker() 97 | def test_cursor_execute(self, m): 98 | m.get( 99 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 100 | json=self.header_payload, 101 | ) 102 | m.get( 103 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 104 | json=self.query_payload, 105 | ) 106 | 107 | with Connection() as conn: 108 | cursor = conn.cursor() 109 | result = cursor.execute( 110 | 'SELECT * FROM "http://docs.google.com/"').fetchall() 111 | Row = namedtuple('Row', 'country cnt') 112 | expected = [Row(country=u'BR', cnt=1.0), Row(country=u'IN', cnt=2.0)] 113 | self.assertEqual(result, expected) 114 | 115 | def test_cursor_executemany(self): 116 | conn = Connection() 117 | cursor = conn.cursor() 118 | with self.assertRaises(exceptions.NotSupportedError): 119 | cursor.executemany('SELECT * FROM "http://docs.google.com/"') 120 | 121 | @requests_mock.Mocker() 122 | def test_cursor(self, m): 123 | m.get( 124 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 125 | json=self.header_payload, 126 | ) 127 | m.get( 128 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 129 | json=self.query_payload, 130 | ) 131 | 132 | conn = Connection() 133 | cursor = conn.cursor() 134 | cursor.setinputsizes(0) # no-op 135 | cursor.setoutputsizes(0) # no-op 136 | 137 | @requests_mock.Mocker() 138 | def test_cursor_rowcount(self, m): 139 | m.get( 140 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 141 | 
json=self.header_payload, 142 | ) 143 | m.get( 144 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 145 | json=self.query_payload, 146 | ) 147 | 148 | conn = Connection() 149 | cursor = conn.cursor() 150 | 151 | with self.assertRaises(exceptions.Error): 152 | cursor.rowcount() 153 | 154 | cursor.execute('SELECT * FROM "http://docs.google.com/"') 155 | self.assertEqual(cursor.rowcount, 2) 156 | 157 | @requests_mock.Mocker() 158 | def test_cursor_fetchone(self, m): 159 | m.get( 160 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 161 | json=self.header_payload, 162 | ) 163 | m.get( 164 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 165 | json=self.query_payload, 166 | ) 167 | 168 | conn = Connection() 169 | cursor = conn.cursor() 170 | cursor.execute('SELECT * FROM "http://docs.google.com/"') 171 | Row = namedtuple('Row', 'country cnt') 172 | 173 | self.assertEqual(cursor.fetchone(), Row(country=u'BR', cnt=1.0)) 174 | self.assertEqual(cursor.fetchone(), Row(country=u'IN', cnt=2.0)) 175 | self.assertIsNone(cursor.fetchone()) 176 | 177 | @requests_mock.Mocker() 178 | def test_cursor_fetchall(self, m): 179 | m.get( 180 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 181 | json=self.header_payload, 182 | ) 183 | m.get( 184 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 185 | json=self.query_payload, 186 | ) 187 | 188 | conn = Connection() 189 | cursor = conn.cursor() 190 | cursor.execute('SELECT * FROM "http://docs.google.com/"') 191 | Row = namedtuple('Row', 'country cnt') 192 | 193 | self.assertEqual(cursor.fetchone(), Row(country=u'BR', cnt=1.0)) 194 | self.assertEqual(cursor.fetchall(), [Row(country=u'IN', cnt=2.0)]) 195 | 196 | @requests_mock.Mocker() 197 | def test_cursor_fetchmany(self, m): 198 | m.get( 199 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 200 | json=self.header_payload, 201 | ) 202 | m.get( 203 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 
204 | json=self.query_payload, 205 | ) 206 | 207 | conn = Connection() 208 | cursor = conn.cursor() 209 | cursor.execute('SELECT * FROM "http://docs.google.com/"') 210 | Row = namedtuple('Row', 'country cnt') 211 | 212 | self.assertEqual(cursor.fetchmany(1), [Row(country=u'BR', cnt=1.0)]) 213 | self.assertEqual(cursor.fetchmany(10), [Row(country=u'IN', cnt=2.0)]) 214 | self.assertEqual(cursor.fetchmany(100), []) 215 | 216 | @requests_mock.Mocker() 217 | def test_cursor_iter(self, m): 218 | m.get( 219 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 220 | json=self.header_payload, 221 | ) 222 | m.get( 223 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 224 | json=self.query_payload, 225 | ) 226 | 227 | conn = Connection() 228 | cursor = conn.cursor() 229 | cursor.execute('SELECT * FROM "http://docs.google.com/"') 230 | Row = namedtuple('Row', 'country cnt') 231 | 232 | self.assertEqual( 233 | list(cursor), 234 | [Row(country=u'BR', cnt=1.0), Row(country=u'IN', cnt=2.0)], 235 | ) 236 | 237 | def test_apply_parameters(self): 238 | query = 'SELECT * FROM table WHERE name=%(name)s' 239 | parameters = {'name': 'Alice'} 240 | result = apply_parameters(query, parameters) 241 | expected = "SELECT * FROM table WHERE name='Alice'" 242 | self.assertEqual(result, expected) 243 | 244 | def test_apply_parameters_escape(self): 245 | query = 'SELECT * FROM table WHERE name=%(name)s' 246 | parameters = {'name': "O'Malley's"} 247 | result = apply_parameters(query, parameters) 248 | expected = "SELECT * FROM table WHERE name='O''Malley''s'" 249 | self.assertEqual(result, expected) 250 | 251 | def test_apply_parameters_float(self): 252 | query = 'SELECT * FROM table WHERE age=%(age)s' 253 | parameters = {'age': 50} 254 | result = apply_parameters(query, parameters) 255 | expected = "SELECT * FROM table WHERE age=50" 256 | self.assertEqual(result, expected) 257 | 258 | def test_apply_parameters_bool(self): 259 | query = 'SELECT * FROM table WHERE 
active=%(active)s' 260 | parameters = {'active': True} 261 | result = apply_parameters(query, parameters) 262 | expected = "SELECT * FROM table WHERE active=TRUE" 263 | self.assertEqual(result, expected) 264 | 265 | def test_apply_parameters_list(self): 266 | query = ( 267 | 'SELECT * FROM table ' 268 | 'WHERE id IN %(allowed)s ' 269 | 'AND id NOT IN %(prohibited)s' 270 | ) 271 | parameters = {'allowed': [1, 2], 'prohibited': (2, 3)} 272 | result = apply_parameters(query, parameters) 273 | expected = ( 274 | 'SELECT * FROM table ' 275 | 'WHERE id IN (1, 2) ' 276 | 'AND id NOT IN (2, 3)' 277 | ) 278 | self.assertEqual(result, expected) 279 | 280 | def test_apply_parameters_star(self): 281 | query = 'SELECT %(column)s FROM table' 282 | parameters = {'column': '*'} 283 | result = apply_parameters(query, parameters) 284 | expected = "SELECT * FROM table" 285 | self.assertEqual(result, expected) 286 | -------------------------------------------------------------------------------- /tests/test_dialect.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | try: 4 | from unittest.mock import Mock 5 | except ImportError: 6 | from mock import Mock 7 | 8 | import unittest 9 | 10 | import requests_mock 11 | from sqlalchemy import MetaData, select, Table 12 | from sqlalchemy.engine import create_engine 13 | from sqlalchemy.engine.url import make_url 14 | from sqlalchemy.sql import sqltypes 15 | 16 | from .context import add_headers, connect, gsheetsdb, GSheetsDialect, Type 17 | 18 | 19 | class DialectTestSuite(unittest.TestCase): 20 | 21 | def test_add_headers(self): 22 | url = 'http://docs.google.com/' 23 | headers = 10 24 | result = add_headers(url, headers) 25 | expected = 'http://docs.google.com/?headers=10&gid=0' 26 | self.assertEqual(result, expected) 27 | 28 | def test_add_headers_with_gid(self): 29 | url = 'http://docs.google.com/#gid=10' 30 | headers = 10 31 | result = add_headers(url, headers) 32 | 
expected = 'http://docs.google.com/?headers=10&gid=10' 33 | self.assertEqual(result, expected) 34 | 35 | def test_cls_dbapi(self): 36 | self.assertEqual(GSheetsDialect.dbapi(), gsheetsdb) 37 | 38 | def test_create_connect_args(self): 39 | dialect = GSheetsDialect() 40 | 41 | url = make_url('gsheets://') 42 | args = dialect.create_connect_args(url) 43 | self.assertEqual(args, ([None], {})) 44 | self.assertIsNone(dialect.url) 45 | 46 | url = make_url('gsheets://docs.google.com/') 47 | args = dialect.create_connect_args(url) 48 | self.assertEqual(args, ([None], {})) 49 | self.assertEqual( 50 | dialect.url, '{0}://docs.google.com/'.format(dialect.scheme)) 51 | 52 | def test_get_schema_names(self): 53 | connection = Mock() 54 | connection.execute = Mock() 55 | result = Mock() 56 | result.fetchall = Mock() 57 | result.fetchall.return_value = [('default', 2), ('public', 4)] 58 | connection.execute.return_value = result 59 | 60 | dialect = GSheetsDialect() 61 | url = make_url('gsheets://docs.google.com/') 62 | dialect.create_connect_args(url) 63 | result = dialect.get_schema_names(connection) 64 | expected = ['default', 'public'] 65 | self.assertEqual(result, expected) 66 | 67 | def test_get_schema_names_no_catalog(self): 68 | connection = connect() 69 | dialect = GSheetsDialect() 70 | url = make_url('gsheets://') 71 | dialect.create_connect_args(url) 72 | result = dialect.get_schema_names(connection) 73 | expected = [] 74 | self.assertEqual(result, expected) 75 | 76 | def test_get_table_names(self): 77 | connection = Mock() 78 | connection.execute = Mock() 79 | result = Mock() 80 | result.fetchall = Mock() 81 | result.fetchall.return_value = [ 82 | ('http://docs.google.com/edit#gid=0', 2), 83 | ('http://docs.google.com/edit#gid=1', 1), 84 | ] 85 | connection.execute.return_value = result 86 | 87 | dialect = GSheetsDialect() 88 | url = make_url('gsheets://docs.google.com/') 89 | dialect.create_connect_args(url) 90 | result = dialect.get_table_names(connection) 91 | 
expected = [ 92 | 'http://docs.google.com/edit?headers=2&gid=0', 93 | 'http://docs.google.com/edit?headers=1&gid=1', 94 | ] 95 | self.assertEqual(result, expected) 96 | 97 | def test_get_table_names_no_catalog(self): 98 | connection = connect() 99 | dialect = GSheetsDialect() 100 | url = make_url('gsheets://') 101 | dialect.create_connect_args(url) 102 | result = dialect.get_table_names(connection) 103 | expected = [] 104 | self.assertEqual(result, expected) 105 | 106 | def test_has_table(self): 107 | connection = Mock() 108 | connection.execute = Mock() 109 | result = Mock() 110 | result.fetchall = Mock() 111 | result.fetchall.return_value = [ 112 | ('http://docs.google.com/edit#gid=0', 2), 113 | ('http://docs.google.com/edit#gid=1', 1), 114 | ] 115 | connection.execute.return_value = result 116 | 117 | dialect = GSheetsDialect() 118 | url = make_url('gsheets://docs.google.com/') 119 | dialect.create_connect_args(url) 120 | 121 | self.assertTrue( 122 | dialect.has_table( 123 | connection, 'http://docs.google.com/edit?headers=2&gid=0')) 124 | self.assertFalse( 125 | dialect.has_table( 126 | connection, 'http://docs.google.com/edit?headers=2&gid=1')) 127 | 128 | def test_has_table_no_catalog(self): 129 | connection = connect() 130 | dialect = GSheetsDialect() 131 | url = make_url('gsheets://') 132 | dialect.create_connect_args(url) 133 | self.assertTrue(dialect.has_table(connection, 'ANY TABLE')) 134 | 135 | def test_get_columns(self): 136 | description = [ 137 | ('datetime', Type.DATETIME, None, None, None, None, True), 138 | ('number', Type.NUMBER, None, None, None, None, True), 139 | ('boolean', Type.BOOLEAN, None, None, None, None, True), 140 | ('date', Type.DATE, None, None, None, None, True), 141 | ('timeofday', Type.TIMEOFDAY, None, None, None, None, True), 142 | ('string', Type.STRING, None, None, None, None, True), 143 | ] 144 | connection = Mock() 145 | connection.execute = Mock() 146 | result = Mock() 147 | result._cursor_description = Mock() 148 | 
result._cursor_description.return_value = description 149 | connection.execute.return_value = result 150 | 151 | dialect = GSheetsDialect() 152 | url = make_url('gsheets://docs.google.com/') 153 | dialect.create_connect_args(url) 154 | 155 | result = dialect.get_columns(connection, 'SOME TABLE') 156 | expected = [ 157 | { 158 | 'name': 'datetime', 159 | 'type': sqltypes.DATETIME, 160 | 'nullable': True, 161 | 'default': None, 162 | }, 163 | { 164 | 'name': 'number', 165 | 'type': sqltypes.Numeric, 166 | 'nullable': True, 167 | 'default': None, 168 | }, 169 | { 170 | 'name': 'boolean', 171 | 'type': sqltypes.Boolean, 172 | 'nullable': True, 173 | 'default': None, 174 | }, 175 | { 176 | 'name': 'date', 177 | 'type': sqltypes.DATE, 178 | 'nullable': True, 179 | 'default': None, 180 | }, 181 | { 182 | 'name': 'timeofday', 183 | 'type': sqltypes.TIME, 184 | 'nullable': True, 185 | 'default': None, 186 | }, 187 | { 188 | 'name': 'string', 189 | 'type': sqltypes.String, 190 | 'nullable': True, 191 | 'default': None, 192 | }, 193 | ] 194 | self.assertEqual(result, expected) 195 | 196 | def test_get_view_names(self): 197 | connection = connect() 198 | dialect = GSheetsDialect() 199 | result = dialect.get_view_names(connection) 200 | expected = [] 201 | self.assertEqual(result, expected) 202 | 203 | def test_get_table_options(self): 204 | connection = connect() 205 | table_name = 'http://docs.google.com/' 206 | dialect = GSheetsDialect() 207 | result = dialect.get_table_options(connection, table_name) 208 | expected = {} 209 | self.assertEqual(result, expected) 210 | 211 | def test_get_pk_constraint(self): 212 | connection = connect() 213 | table_name = 'http://docs.google.com/' 214 | dialect = GSheetsDialect() 215 | result = dialect.get_pk_constraint(connection, table_name) 216 | expected = {'constrained_columns': [], 'name': None} 217 | self.assertEqual(result, expected) 218 | 219 | def test_get_foreign_keys(self): 220 | connection = connect() 221 | table_name = 
'http://docs.google.com/' 222 | dialect = GSheetsDialect() 223 | result = dialect.get_foreign_keys(connection, table_name) 224 | expected = [] 225 | self.assertEqual(result, expected) 226 | 227 | def test_get_check_constraints(self): 228 | connection = connect() 229 | table_name = 'http://docs.google.com/' 230 | dialect = GSheetsDialect() 231 | result = dialect.get_check_constraints(connection, table_name) 232 | expected = [] 233 | self.assertEqual(result, expected) 234 | 235 | def test_get_table_comment(self): 236 | connection = connect() 237 | table_name = 'http://docs.google.com/' 238 | dialect = GSheetsDialect() 239 | result = dialect.get_table_comment(connection, table_name) 240 | expected = {'text': ''} 241 | self.assertEqual(result, expected) 242 | 243 | def test_get_indexes(self): 244 | connection = connect() 245 | table_name = 'http://docs.google.com/' 246 | dialect = GSheetsDialect() 247 | result = dialect.get_indexes(connection, table_name) 248 | expected = [] 249 | self.assertEqual(result, expected) 250 | 251 | def test_get_unique_constraints(self): 252 | connection = connect() 253 | table_name = 'http://docs.google.com/' 254 | dialect = GSheetsDialect() 255 | result = dialect.get_unique_constraints(connection, table_name) 256 | expected = [] 257 | self.assertEqual(result, expected) 258 | 259 | def test_get_view_definition(self): 260 | connection = connect() 261 | view_name = 'http://docs.google.com/' 262 | dialect = GSheetsDialect() 263 | result = dialect.get_view_definition(connection, view_name) 264 | self.assertIsNone(result) 265 | 266 | def test_do_rollback(self): 267 | connection = connect() 268 | dialect = GSheetsDialect() 269 | result = dialect.do_rollback(connection) 270 | self.assertIsNone(result) 271 | 272 | def test__check_unicode_returns(self): 273 | connection = connect() 274 | dialect = GSheetsDialect() 275 | result = dialect._check_unicode_returns(connection) 276 | self.assertTrue(result) 277 | 278 | def 
test__check_unicode_description(self): 279 | connection = connect() 280 | dialect = GSheetsDialect() 281 | result = dialect._check_unicode_description(connection) 282 | self.assertTrue(result) 283 | 284 | @requests_mock.Mocker() 285 | def test_GSheetsCompiler(self, m): 286 | header_payload = { 287 | 'status': 'ok', 288 | 'table': { 289 | 'cols': [ 290 | {'id': 'A', 'label': 'country', 'type': 'string'}, 291 | { 292 | 'id': 'B', 293 | 'label': 'cnt', 294 | 'type': 'number', 295 | 'pattern': 'General', 296 | }, 297 | ], 298 | 'rows': [], 299 | }, 300 | } 301 | m.get( 302 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A%20LIMIT%200', 303 | json=header_payload, 304 | ) 305 | m.get( 306 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 307 | json=header_payload, 308 | ) 309 | engine = create_engine('gsheets://') 310 | table = Table( 311 | 'http://docs.google.com/', MetaData(bind=engine), autoload=True) 312 | query = select([table.columns.country], from_obj=table) 313 | result = str(query) 314 | expected = 'SELECT country \nFROM "http://docs.google.com/"' 315 | self.assertEqual(result, expected) 316 | -------------------------------------------------------------------------------- /tests/test_processing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from collections import OrderedDict 4 | import unittest 5 | import warnings 6 | 7 | from moz_sql_parser import parse 8 | 9 | from .context import ( 10 | Any, 11 | CountStar, 12 | DateTrunc, 13 | is_subset, 14 | Processor, 15 | SubsetMatcher, 16 | ) 17 | 18 | 19 | class ProcessingTestSuite(unittest.TestCase): 20 | 21 | def test_processor(self): 22 | processor = Processor() 23 | sql = 'SELECT * FROM "http://docs.google.com"' 24 | parsed_query = parse(sql) 25 | payload = {} 26 | 27 | # a base processor never matches, and doesn't do anything 28 | self.assertFalse(processor.match({})) 29 | self.assertEqual(parsed_query, 
processor.pre_process(parsed_query, {})) 30 | self.assertEqual(payload, processor.post_process(payload, [None])) 31 | 32 | def test_count_star(self): 33 | sql = 'SELECT COUNT(*) AS total FROM "http://docs.google.com"' 34 | parsed_query = parse(sql) 35 | column_map = OrderedDict(sorted({'country': 'A', 'cnt': 'B'}.items())) 36 | 37 | self.assertTrue(CountStar.match(parsed_query)) 38 | 39 | processor = CountStar() 40 | with warnings.catch_warnings(): 41 | warnings.simplefilter("ignore") 42 | result = processor.pre_process(parsed_query, column_map) 43 | expected = parse(''' 44 | SELECT 45 | COUNT(cnt) AS __CountStar__cnt 46 | , COUNT(country) AS __CountStar__country 47 | FROM 48 | "http://docs.google.com" 49 | ''') 50 | self.assertEqual(result, expected) 51 | 52 | payload = { 53 | 'status': 'ok', 54 | 'table': { 55 | 'cols': [ 56 | { 57 | 'id': 'count-A', 58 | 'label': 'count country', 59 | 'type': 'number', 60 | }, 61 | {'id': 'count-B', 'label': 'count cnt', 'type': 'number'}, 62 | ], 63 | 'rows': [ 64 | { 65 | 'c': [ 66 | {'v': 9.0}, 67 | {'v': 8.0}, 68 | ], 69 | }, 70 | ], 71 | }, 72 | } 73 | aliases = ['__CountStar__country', '__CountStar__cnt'] 74 | result = processor.post_process(payload, aliases) 75 | expected = { 76 | 'status': 'ok', 77 | 'table': { 78 | 'cols': [ 79 | {'id': 'count-star', 'label': 'total', 'type': 'number'}, 80 | ], 81 | 'rows': [{'c': [{'v': 9.0}]}], 82 | }, 83 | } 84 | self.assertEqual(result, expected) 85 | 86 | def test_count_star_no_results(self): 87 | sql = 'SELECT COUNT(*) AS total FROM "http://docs.google.com"' 88 | parsed_query = parse(sql) 89 | column_map = OrderedDict(sorted({'country': 'A', 'cnt': 'B'}.items())) 90 | 91 | processor = CountStar() 92 | with warnings.catch_warnings(): 93 | warnings.simplefilter("ignore") 94 | processor.pre_process(parsed_query, column_map) 95 | 96 | payload = { 97 | 'status': 'ok', 98 | 'table': { 99 | 'cols': [ 100 | { 101 | 'id': 'count-A', 102 | 'label': 'count country', 103 | 'type': 'number', 
104 | }, 105 | {'id': 'count-B', 'label': 'count cnt', 'type': 'number'}, 106 | ], 107 | 'rows': [], 108 | }, 109 | } 110 | aliases = ['__CountStar__country', '__CountStar__cnt'] 111 | result = processor.post_process(payload, aliases) 112 | expected = { 113 | 'status': 'ok', 114 | 'table': { 115 | 'cols': [ 116 | {'id': 'count-star', 'label': 'total', 'type': 'number'}, 117 | ], 118 | 'rows': [{'c': [{'v': 0}]}], 119 | }, 120 | } 121 | self.assertEqual(result, expected) 122 | 123 | def test_count_star_with_groupby(self): 124 | sql = ( 125 | 'SELECT country, COUNT(*) FROM "http://docs.google.com" ' 126 | 'GROUP BY country' 127 | ) 128 | parsed_query = parse(sql) 129 | column_map = OrderedDict(sorted({'country': 'A', 'cnt': 'B'}.items())) 130 | 131 | self.assertTrue(CountStar.match(parsed_query)) 132 | 133 | processor = CountStar() 134 | 135 | with warnings.catch_warnings(): 136 | warnings.simplefilter("ignore") 137 | result = processor.pre_process(parsed_query, column_map) 138 | expected = parse(''' 139 | SELECT 140 | country 141 | , COUNT(cnt) AS __CountStar__cnt 142 | , COUNT(country) AS __CountStar__country 143 | FROM 144 | "http://docs.google.com" 145 | GROUP BY 146 | country 147 | ''') 148 | self.assertEqual(result, expected) 149 | 150 | payload = { 151 | 'status': 'ok', 152 | 'table': { 153 | 'cols': [ 154 | {'id': 'A', 'label': 'country', 'type': 'string'}, 155 | { 156 | 'id': 'count-B', 157 | 'label': 'count country', 158 | 'type': 'number', 159 | }, 160 | {'id': 'count-C', 'label': 'count cnt', 'type': 'number'}, 161 | ], 162 | 'rows': [ 163 | {'c': [{'v': 'BR'}, {'v': 4.0}, {'v': 3.0}]}, 164 | {'c': [{'v': 'IN'}, {'v': 5.0}, {'v': 1.0}]}, 165 | ], 166 | }, 167 | } 168 | aliases = ['country', '__CountStar__country', '__CountStar__cnt'] 169 | result = processor.post_process(payload, aliases) 170 | expected = { 171 | 'status': 'ok', 172 | 'table': { 173 | 'cols': [ 174 | {'id': 'A', 'label': 'country', 'type': 'string'}, 175 | { 176 | 'id': 'count-star', 177 
| 'label': 'count star', 178 | 'type': 'number', 179 | }, 180 | ], 181 | 'rows': [ 182 | {'c': [{'v': 'BR'}, {'v': 4.0}]}, 183 | {'c': [{'v': 'IN'}, {'v': 5.0}]}, 184 | ], 185 | }, 186 | } 187 | self.assertEqual(result, expected) 188 | 189 | def test_subset_matcher(self): 190 | pattern = SubsetMatcher({'select': {'value': {'count': '*'}}}) 191 | 192 | parsed_query = parse('SELECT COUNT(*)') 193 | self.assertTrue(pattern.match(parsed_query)) 194 | 195 | parsed_query = parse('SELECT COUNT(*) AS total') 196 | self.assertTrue(pattern.match(parsed_query)) 197 | 198 | parsed_query = parse('SELECT COUNT(*) AS total, country') 199 | self.assertTrue(pattern.match(parsed_query)) 200 | 201 | parsed_query = parse('SELECT cnt, COUNT(*) AS total, country') 202 | self.assertTrue(pattern.match(parsed_query)) 203 | 204 | parsed_query = parse('SELECT country') 205 | self.assertFalse(pattern.match(parsed_query)) 206 | 207 | parsed_query = parse( 208 | 'SELECT country, COUNT(*) FROM "http://docs.google.com" ' 209 | 'GROUP BY country') 210 | self.assertTrue(pattern.match(parsed_query)) 211 | 212 | def test_is_subset(self): 213 | json = [1, 2, 3] 214 | 215 | other = [1, 2, 3, 4] 216 | self.assertTrue(is_subset(json, other)) 217 | 218 | other = 1 219 | self.assertFalse(is_subset(json, other)) 220 | 221 | other = [1, 3, 4] 222 | self.assertFalse(is_subset(json, other)) 223 | 224 | def test_any_match(self): 225 | pattern = SubsetMatcher({'name': Any()}) 226 | self.assertTrue(pattern.match({'name': 'Alice'})) 227 | self.assertTrue(pattern.match({'name': 'Bob'})) 228 | 229 | def test_datetrunc_match(self): 230 | self.assertTrue( 231 | DateTrunc.match({ 232 | 'select': { 233 | 'value': { 234 | 'datetrunc': [ 235 | {'literal': 'month'}, 'datetime', 236 | ], 237 | }, 238 | }, 239 | }) 240 | ) 241 | self.assertFalse( 242 | DateTrunc.match({ 243 | 'select': { 244 | 'value': { 245 | 'datetrunc': [{'literal': 'week'}, 'datetime'], 246 | }, 247 | }, 248 | }) 249 | ) 250 | 
self.assertFalse(DateTrunc.match({'select': {'value': 'datetime'}})) 251 | self.assertFalse( 252 | DateTrunc.match({'select': {'value': {'year': 'datetime'}}})) 253 | self.assertFalse( 254 | DateTrunc.match({'select': {'value': {'month': 'datetime'}}})) 255 | -------------------------------------------------------------------------------- /tests/test_query.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | try: 4 | from unittest.mock import Mock 5 | except ImportError: 6 | from mock import Mock 7 | 8 | from collections import namedtuple 9 | import unittest 10 | 11 | import requests_mock 12 | from six import BytesIO 13 | from urllib3.response import HTTPResponse 14 | 15 | from .context import ( 16 | exceptions, 17 | execute, 18 | get_column_map, 19 | get_description_from_payload, 20 | LEADING, 21 | run_query, 22 | Type, 23 | ) 24 | 25 | 26 | class QueryTestSuite(unittest.TestCase): 27 | 28 | @requests_mock.Mocker() 29 | def test_get_column_map(self, m): 30 | payload = { 31 | 'table': { 32 | 'cols': [ 33 | {'id': 'A', 'label': 'country', 'type': 'string'}, 34 | { 35 | 'id': 'B', 36 | 'label': 'cnt', 37 | 'type': 'number', 38 | 'pattern': 'General', 39 | }, 40 | ], 41 | }, 42 | } 43 | m.get('http://docs.google.com/&tq=SELECT%20%2A%20LIMIT%200', 44 | json=payload) 45 | 46 | url = 'http://docs.google.com/' 47 | result = get_column_map(url) 48 | expected = {'country': 'A', 'cnt': 'B'} 49 | self.assertEqual(result, expected) 50 | 51 | @requests_mock.Mocker() 52 | def test_run_query(self, m): 53 | m.get('http://docs.google.com/&tq=SELECT%20%2A', json='ok') 54 | 55 | baseurl = 'http://docs.google.com/' 56 | query = 'SELECT *' 57 | result = run_query(baseurl, query) 58 | expected = 'ok' 59 | self.assertEqual(result, expected) 60 | 61 | @requests_mock.Mocker() 62 | def test_run_query_with_credentials(self, m): 63 | m.get('http://docs.google.com/&tq=SELECT%20%2A', json='ok') 64 | credentials = Mock() 65 | 
credentials.before_request = Mock() 66 | 67 | baseurl = 'http://docs.google.com/' 68 | query = 'SELECT *' 69 | result = run_query(baseurl, query, credentials) 70 | expected = 'ok' 71 | self.assertEqual(result, expected) 72 | 73 | @requests_mock.Mocker() 74 | def test_run_query_error(self, m): 75 | m.get( 76 | 'http://docs.google.com/&tq=SELECT%20%2A', 77 | text='Error', 78 | status_code=500, 79 | ) 80 | 81 | baseurl = 'http://docs.google.com/' 82 | query = 'SELECT *' 83 | with self.assertRaises(exceptions.ProgrammingError): 84 | run_query(baseurl, query) 85 | 86 | @requests_mock.Mocker() 87 | def test_run_query_leading(self, m): 88 | text = '{0}{1}'.format(LEADING, '"ok"') 89 | m.get('http://docs.google.com/&tq=SELECT%20%2A', text=text) 90 | 91 | baseurl = 'http://docs.google.com/' 92 | query = 'SELECT *' 93 | result = run_query(baseurl, query) 94 | expected = 'ok' 95 | self.assertEqual(result, expected) 96 | 97 | @requests_mock.Mocker() 98 | def test_run_query_no_encoding(self, m): 99 | raw = HTTPResponse( 100 | body=BytesIO('"ok"'.encode('utf-8')), 101 | preload_content=False, 102 | headers={ 103 | 'Content-type': 'application/json', 104 | }, 105 | status=200, 106 | ) 107 | m.get('http://docs.google.com/&tq=SELECT%20%2A', raw=raw) 108 | 109 | baseurl = 'http://docs.google.com/' 110 | query = 'SELECT *' 111 | result = run_query(baseurl, query) 112 | expected = 'ok' 113 | self.assertEqual(result, expected) 114 | 115 | def test_get_description_from_payload(self): 116 | payload = { 117 | 'table': { 118 | 'cols': [ 119 | { 120 | 'id': 'A', 121 | 'label': 'datetime', 122 | 'type': 'datetime', 123 | 'pattern': 'M/d/yyyy H:mm:ss', 124 | }, 125 | { 126 | 'id': 'B', 127 | 'label': 'number', 128 | 'type': 'number', 129 | 'pattern': 'General', 130 | }, 131 | {'id': 'C', 'label': 'boolean', 'type': 'boolean'}, 132 | { 133 | 'id': 'D', 134 | 'label': 'date', 135 | 'type': 'date', 136 | 'pattern': 'M/d/yyyy', 137 | }, 138 | { 139 | 'id': 'E', 140 | 'label': 'timeofday', 141 | 
'type': 'timeofday', 142 | 'pattern': 'h:mm:ss am/pm', 143 | }, 144 | {'id': 'F', 'label': 'string', 'type': 'string'}, 145 | ], 146 | }, 147 | } 148 | result = get_description_from_payload(payload) 149 | expected = [ 150 | ('datetime', Type.DATETIME, None, None, None, None, True), 151 | ('number', Type.NUMBER, None, None, None, None, True), 152 | ('boolean', Type.BOOLEAN, None, None, None, None, True), 153 | ('date', Type.DATE, None, None, None, None, True), 154 | ('timeofday', Type.TIMEOFDAY, None, None, None, None, True), 155 | ('string', Type.STRING, None, None, None, None, True), 156 | ] 157 | self.assertEqual(result, expected) 158 | 159 | @requests_mock.Mocker() 160 | def test_execute(self, m): 161 | header_payload = { 162 | 'table': { 163 | 'cols': [ 164 | {'id': 'A', 'label': 'country', 'type': 'string'}, 165 | { 166 | 'id': 'B', 167 | 'label': 'cnt', 168 | 'type': 'number', 169 | 'pattern': 'General', 170 | }, 171 | ], 172 | }, 173 | } 174 | m.get( 175 | 'http://docs.google.com/gviz/tq?headers=1&gid=0&' 176 | 'tq=SELECT%20%2A%20LIMIT%200', 177 | json=header_payload, 178 | ) 179 | query_payload = { 180 | 'status': 'ok', 181 | 'table': { 182 | 'cols': [ 183 | {'id': 'A', 'label': 'country', 'type': 'string'}, 184 | { 185 | 'id': 'B', 186 | 'label': 'cnt', 187 | 'type': 'number', 188 | 'pattern': 'General', 189 | }, 190 | ], 191 | 'rows': [{'c': [{'v': 'BR'}, {'v': 1.0, 'f': '1'}]}], 192 | }, 193 | } 194 | m.get( 195 | 'http://docs.google.com/gviz/tq?headers=1&gid=0&tq=SELECT%20%2A', 196 | json=query_payload, 197 | ) 198 | 199 | query = 'SELECT * FROM "http://docs.google.com/"' 200 | headers = 1 201 | results, description = execute(query, headers) 202 | Row = namedtuple('Row', 'country cnt') 203 | self.assertEqual(results, [Row(country=u'BR', cnt=1.0)]) 204 | self.assertEqual( 205 | description, 206 | [ 207 | ('country', Type.STRING, None, None, None, None, True), 208 | ('cnt', Type.NUMBER, None, None, None, None, True), 209 | ], 210 | ) 211 | 212 | 
@requests_mock.Mocker() 213 | def test_execute_with_processor(self, m): 214 | header_payload = { 215 | 'table': { 216 | 'cols': [ 217 | {'id': 'A', 'label': 'country', 'type': 'string'}, 218 | { 219 | 'id': 'B', 220 | 'label': 'cnt', 221 | 'type': 'number', 222 | 'pattern': 'General', 223 | }, 224 | ], 225 | }, 226 | } 227 | m.get( 228 | 'http://docs.google.com/gviz/tq?headers=1&gid=0&' 229 | 'tq=SELECT%20%2A%20LIMIT%200', 230 | json=header_payload, 231 | ) 232 | query_payload = { 233 | 'status': 'ok', 234 | 'table': { 235 | 'cols': [ 236 | { 237 | 'id': 'count-A', 238 | 'label': 'count country', 239 | 'type': 'number', 240 | }, 241 | { 242 | 'id': 'count-B', 243 | 'label': 'count cnt', 244 | 'type': 'number', 245 | }, 246 | ], 247 | 'rows': [{'c': [{'v': 5.0}, {'v': 5.0}]}], 248 | }, 249 | } 250 | m.get( 251 | 'http://docs.google.com/gviz/tq?headers=1&gid=0&' 252 | 'tq=SELECT%20COUNT(B)%2C%20COUNT(A)', 253 | json=query_payload, 254 | ) 255 | 256 | query = 'SELECT COUNT(*) AS total FROM "http://docs.google.com/"' 257 | headers = 1 258 | results, description = execute(query, headers) 259 | Row = namedtuple('Row', 'total') 260 | self.assertEqual(results, [Row(total=5.0)]) 261 | self.assertEqual( 262 | description, 263 | [('total', Type.NUMBER, None, None, None, None, True)], 264 | ) 265 | 266 | @requests_mock.Mocker() 267 | def test_execute_with_empty_columns(self, m): 268 | header_payload = { 269 | 'table': { 270 | 'cols': [ 271 | {'id': 'A', 'label': 'country', 'type': 'string'}, 272 | { 273 | 'id': 'B', 274 | 'label': 'cnt', 275 | 'type': 'number', 276 | 'pattern': 'General', 277 | }, 278 | {'id': 'C', 'label': '', 'type': 'string'} 279 | ], 280 | }, 281 | } 282 | m.get( 283 | 'http://docs.google.com/gviz/tq?gid=0&' 284 | 'tq=SELECT%20%2A%20LIMIT%200', 285 | json=header_payload, 286 | ) 287 | query_payload = { 288 | 'status': 'ok', 289 | 'table': { 290 | 'cols': [ 291 | {'id': 'A', 'label': 'country', 'type': 'string'}, 292 | { 293 | 'id': 'B', 294 | 'label': 
'cnt', 295 | 'type': 'number', 296 | 'pattern': 'General', 297 | }, 298 | {'id': 'C', 'label': '', 'type': 'string'}, 299 | ], 300 | 'rows': [{'c': [{'v': 'BR'}, {'v': 1.0, 'f': '1'}, None]}], 301 | }, 302 | } 303 | m.get( 304 | 'http://docs.google.com/gviz/tq?gid=0&tq=SELECT%20%2A', 305 | json=query_payload, 306 | ) 307 | 308 | query = 'SELECT * FROM "http://docs.google.com/"' 309 | results, description = execute(query) 310 | Row = namedtuple('Row', 'country cnt') 311 | self.assertEqual(results, [Row(country=u'BR', cnt=1.0)]) 312 | self.assertEqual( 313 | description, 314 | [ 315 | ('country', Type.STRING, None, None, None, None, True), 316 | ('cnt', Type.NUMBER, None, None, None, None, True), 317 | ], 318 | ) 319 | 320 | def test_execute_bad_query(self): 321 | with self.assertRaises(exceptions.ProgrammingError): 322 | execute('SELECT ORDER BY FROM table') 323 | 324 | def test_execute_invalid_url(self): 325 | with self.assertRaises(exceptions.InterfaceError): 326 | execute('SELECT * FROM "http://example.com/"') 327 | 328 | @requests_mock.Mocker() 329 | def test_execute_gsheets_error(self, m): 330 | header_payload = { 331 | 'table': { 332 | 'cols': [ 333 | {'id': 'A', 'label': 'country', 'type': 'string'}, 334 | { 335 | 'id': 'B', 336 | 'label': 'cnt', 337 | 'type': 'number', 338 | 'pattern': 'General', 339 | }, 340 | ], 341 | }, 342 | } 343 | m.get( 344 | 'http://docs.google.com/gviz/tq?headers=1&gid=0&' 345 | 'tq=SELECT%20%2A%20LIMIT%200', 346 | json=header_payload, 347 | ) 348 | query_payload = { 349 | 'status': 'error', 350 | 'errors': [{'detailed_message': 'Error!'}], 351 | } 352 | m.get( 353 | 'http://docs.google.com/gviz/tq?headers=1&gid=0&' 354 | 'tq=SELECT%20COUNT(B)%2C%20COUNT(A)', 355 | json=query_payload, 356 | ) 357 | 358 | query = 'SELECT COUNT(*) AS total FROM "http://docs.google.com/"' 359 | headers = 1 360 | with self.assertRaises(exceptions.ProgrammingError): 361 | execute(query, headers) 362 | 
-------------------------------------------------------------------------------- /tests/test_translation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import unittest 4 | 5 | from moz_sql_parser import parse 6 | 7 | from .context import exceptions, extract_column_aliases, translate 8 | 9 | 10 | class TranslationTestSuite(unittest.TestCase): 11 | 12 | def test_select(self): 13 | sql = 'SELECT country FROM "http://docs.google.com"' 14 | expected = "SELECT A" 15 | result = translate(parse(sql), {'country': 'A'}) 16 | self.assertEqual(result, expected) 17 | 18 | def test_from(self): 19 | sql = 'SELECT * FROM table' 20 | expected = 'SELECT *' 21 | result = translate(parse(sql), {}) 22 | self.assertEqual(result, expected) 23 | 24 | def test_where(self): 25 | sql = 'SELECT country FROM "http://docs.google.com" WHERE cnt > 10' 26 | expected = 'SELECT A WHERE B > 10' 27 | result = translate(parse(sql), {'country': 'A', 'cnt': 'B'}) 28 | self.assertEqual(result, expected) 29 | 30 | def test_where_groupby(self): 31 | sql = ''' 32 | SELECT 33 | country 34 | , SUM(cnt) 35 | FROM 36 | "http://docs.google.com" 37 | WHERE 38 | country != 'US' 39 | GROUP BY 40 | country 41 | ''' 42 | expected = "SELECT A, SUM(B) WHERE A <> 'US' GROUP BY A" 43 | result = translate(parse(sql), {'country': 'A', 'cnt': 'B'}) 44 | self.assertEqual(result, expected) 45 | 46 | def test_groupby(self): 47 | sql = ( 48 | 'SELECT country, SUM(cnt) FROM "http://docs.google.com" ' 49 | 'GROUP BY country' 50 | ) 51 | expected = "SELECT A, SUM(B) GROUP BY A" 52 | result = translate(parse(sql), {'country': 'A', 'cnt': 'B'}) 53 | self.assertEqual(result, expected) 54 | 55 | def test_having(self): 56 | sql = ''' 57 | SELECT 58 | country 59 | , SUM(cnt) 60 | FROM 61 | "http://docs.google.com" 62 | GROUP BY 63 | country 64 | HAVING 65 | COUNT(*) > 0 66 | ''' 67 | with self.assertRaises(exceptions.NotSupportedError): 68 | translate(parse(sql), 
{'country': 'A', 'cnt': 'B'}) 69 | 70 | def test_subquery(self): 71 | sql = 'SELECT * from XYZZY, ABC' 72 | with self.assertRaises(exceptions.NotSupportedError): 73 | translate(parse(sql)) 74 | 75 | def test_orderby(self): 76 | sql = ''' 77 | SELECT 78 | country 79 | , SUM(cnt) 80 | FROM 81 | "http://docs.google.com" 82 | GROUP BY 83 | country 84 | ORDER BY 85 | SUM(cnt) 86 | ''' 87 | expected = "SELECT A, SUM(B) GROUP BY A ORDER BY SUM(B)" 88 | result = translate(parse(sql), {'country': 'A', 'cnt': 'B'}) 89 | self.assertEqual(result, expected) 90 | 91 | def test_limit(self): 92 | sql = 'SELECT country FROM "http://docs.google.com" LIMIT 10' 93 | expected = 'SELECT A LIMIT 10' 94 | result = translate(parse(sql), {'country': 'A'}) 95 | self.assertEqual(result, expected) 96 | 97 | def test_offset(self): 98 | sql = 'SELECT country FROM "http://docs.google.com" LIMIT 10 OFFSET 5' 99 | expected = 'SELECT A LIMIT 10 OFFSET 5' 100 | result = translate(parse(sql), {'country': 'A'}) 101 | self.assertEqual(result, expected) 102 | 103 | def test_alias(self): 104 | sql = 'SELECT SUM(cnt) AS total FROM "http://docs.google.com"' 105 | expected = 'SELECT SUM(B)' 106 | result = translate(parse(sql), {'cnt': 'B'}) 107 | self.assertEqual(result, expected) 108 | 109 | def test_multiple_aliases(self): 110 | sql = ( 111 | 'SELECT country AS dim1, SUM(cnt) AS total ' 112 | 'FROM "http://docs.google.com" GROUP BY country' 113 | ) 114 | expected = 'SELECT A, SUM(B) GROUP BY A' 115 | result = translate(parse(sql), {'country': 'A', 'cnt': 'B'}) 116 | self.assertEqual(result, expected) 117 | 118 | def test_unalias_orderby(self): 119 | sql = ''' 120 | SELECT 121 | cnt AS value 122 | FROM 123 | "http://docs.google.com" 124 | ORDER BY 125 | value 126 | ''' 127 | expected = 'SELECT B ORDER BY B' 128 | result = translate(parse(sql), {'cnt': 'B'}) 129 | self.assertEqual(result, expected) 130 | 131 | def test_column_aliases(self): 132 | sql = 'SELECT SUM(cnt) AS total FROM "http://docs.google.com"' 
133 | expected = ['total'] 134 | result = extract_column_aliases(parse(sql)) 135 | self.assertEqual(result, expected) 136 | 137 | def test_column_aliases_star(self): 138 | sql = 'SELECT * FROM "http://docs.google.com"' 139 | expected = [None] 140 | result = extract_column_aliases(parse(sql)) 141 | self.assertEqual(result, expected) 142 | 143 | def test_column_aliases_multiple(self): 144 | sql = ( 145 | 'SELECT SUM(cnt) AS total, country, gender AS dim1 ' 146 | 'FROM "http://docs.google.com"' 147 | ) 148 | expected = ['total', None, 'dim1'] 149 | result = extract_column_aliases(parse(sql)) 150 | self.assertEqual(result, expected) 151 | 152 | def test_order_by_alias(self): 153 | sql = ''' 154 | SELECT 155 | country AS country 156 | , SUM(cnt) AS "SUM(cnt)" 157 | FROM 158 | "https://docs.google.com" 159 | GROUP BY 160 | country 161 | ORDER BY 162 | "SUM(cnt)" 163 | DESC 164 | LIMIT 10 165 | ''' 166 | expected = 'SELECT A, SUM(B) GROUP BY A ORDER BY SUM(B) DESC LIMIT 10' 167 | result = translate(parse(sql), {'country': 'A', 'cnt': 'B'}) 168 | self.assertEqual(result, expected) 169 | 170 | 171 | if __name__ == '__main__': 172 | unittest.main() 173 | -------------------------------------------------------------------------------- /tests/test_url.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import unittest 4 | 5 | from .context import extract_url, get_url 6 | 7 | 8 | class UrlTestSuite(unittest.TestCase): 9 | 10 | def test_extract_url(self): 11 | query = 'SELECT * FROM "http://docs.google.com"' 12 | result = extract_url(query) 13 | expected = 'http://docs.google.com' 14 | self.assertEqual(result, expected) 15 | 16 | def test_get_url(self): 17 | url = 'http://docs.google.com' 18 | result = get_url(url, headers=1, gid=10, sheet=None) 19 | expected = 'http://docs.google.com/gviz/tq?headers=1&gid=10' 20 | self.assertEqual(result, expected) 21 | 22 | def test_remove_edit(self): 23 | url = 
'http://docs.google.com/edit#gid=0' 24 | result = get_url(url, headers=1, gid=10, sheet=None) 25 | expected = 'http://docs.google.com/gviz/tq?headers=1&gid=0' 26 | self.assertEqual(result, expected) 27 | 28 | def test_url_gid_qs(self): 29 | url = 'http://docs.google.com/?gid=0' 30 | result = get_url(url, headers=1, gid=10, sheet=None) 31 | expected = 'http://docs.google.com/gviz/tq?headers=1&gid=0' 32 | self.assertEqual(result, expected) 33 | 34 | def test_url_headers_qs(self): 35 | url = 'http://docs.google.com/?gid=0&headers=2' 36 | result = get_url(url, headers=1, gid=10, sheet=None) 37 | expected = 'http://docs.google.com/gviz/tq?headers=2&gid=0' 38 | self.assertEqual(result, expected) 39 | 40 | def test_sheet_name(self): 41 | url = 'http://docs.google.com/?gid=0&headers=2&sheet=table' 42 | result = get_url(url, headers=1, gid=10, sheet=None) 43 | expected = 'http://docs.google.com/gviz/tq?headers=2&sheet=table' 44 | self.assertEqual(result, expected) 45 | 46 | def test_extract_url_bad_sql(self): 47 | query = 'SELECTSELECTSELECT' 48 | result = extract_url(query) 49 | self.assertIsNone(result) 50 | 51 | def test_extract_url_using_regex(self): 52 | query = 'INVALID FROM "http://docs.google.com"' 53 | result = extract_url(query) 54 | expected = 'http://docs.google.com' 55 | self.assertEqual(result, expected) 56 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import unittest 4 | 5 | from moz_sql_parser import parse 6 | import pyparsing 7 | 8 | from .context import format_gsheet_error, format_moz_error 9 | 10 | 11 | class UtilsTestSuite(unittest.TestCase): 12 | 13 | def test_format_moz_error(self): 14 | query = 'SELECT ))) FROM table' 15 | with self.assertRaises(pyparsing.ParseException) as context: 16 | parse(query) 17 | 18 | result = format_moz_error(query, context.exception) 19 | expected = ( 
20 | 'SELECT ))) FROM table\n' 21 | ' ^\n' 22 | 'Expected {{expression1 [{[as] column_name1}]} | "*"} ' 23 | '(at char 7), (line:1, col:8)' 24 | ) 25 | self.assertEqual(result, expected) 26 | 27 | def test_format_gsheet_error(self): 28 | query = 'SELECT A + B FROM "http://docs.google.com"' 29 | translated_query = 'SELECT A + B' 30 | errors = [{ 31 | 'reason': 'invalid_query', 32 | 'detailed_message': ( 33 | "Invalid query: Can't perform the function sum on values that " 34 | "are not numbers" 35 | ), 36 | 'message': 'INVALID_QUERY', 37 | }] 38 | 39 | result = format_gsheet_error(query, translated_query, errors) 40 | expected = ( 41 | 'Original query:\n' 42 | 'SELECT A + B FROM "http://docs.google.com"\n\n' 43 | 'Translated query:\n' 44 | 'SELECT A + B\n\n' 45 | 'Error:\n' 46 | "Invalid query: Can't perform the function sum on values that " 47 | "are not numbers" 48 | ) 49 | self.assertEqual(result, expected) 50 | 51 | def test_format_gsheet_error_caret(self): 52 | query = 'SELECT A IS NULL FROM "http://docs.google.com"' 53 | translated_query = 'SELECT A IS NULL' 54 | errors = [{ 55 | 'reason': 'invalid_query', 56 | 'detailed_message': ( 57 | 'Invalid query: PARSE_ERROR: Encountered " "is" "IS "" at ' 58 | 'line 1, column 10.\nWas expecting one of:\n' 59 | ' \n' 60 | ' "where" ...\n' 61 | ' "group" ...\n' 62 | ' "pivot" ...\n' 63 | ' "order" ...\n' 64 | ' "skipping" ...\n' 65 | ' "limit" ...\n' 66 | ' "offset" ...\n' 67 | ' "label" ...\n' 68 | ' "format" ...\n' 69 | ' "options" ...\n' 70 | ' "," ...\n' 71 | ' "*" ...\n' 72 | ' "+" ...\n' 73 | ' "-" ...\n' 74 | ' "/" ...\n' 75 | ' "%" ...\n' 76 | ' "*" ...\n' 77 | ' "/" ...\n' 78 | ' "%" ...\n' 79 | ' "+" ...\n' 80 | ' "-" ...\n' 81 | ' ' 82 | ), 83 | 'message': 'INVALID_QUERY', 84 | }] 85 | 86 | result = format_gsheet_error(query, translated_query, errors) 87 | expected = ( 88 | 'Original query:\n' 89 | 'SELECT A IS NULL FROM "http://docs.google.com"\n\n' 90 | 'Translated query:\n' 91 | 'SELECT A IS NULL\n\n' 92 | 
'Error:\n' 93 | 'SELECT A IS NULL\n' 94 | ' ^\n' 95 | 'Invalid query: PARSE_ERROR: Encountered " "is" "IS "" at line 1, ' 96 | 'column 10.\n' 97 | 'Was expecting one of:\n' 98 | ' \n' 99 | ' "where" ...\n' 100 | ' "group" ...\n' 101 | ' "pivot" ...\n' 102 | ' "order" ...\n' 103 | ' "skipping" ...\n' 104 | ' "limit" ...\n' 105 | ' "offset" ...\n' 106 | ' "label" ...\n' 107 | ' "format" ...\n' 108 | ' "options" ...\n' 109 | ' "," ...\n' 110 | ' "*" ...\n' 111 | ' "+" ...\n' 112 | ' "-" ...\n' 113 | ' "/" ...\n' 114 | ' "%" ...\n' 115 | ' "*" ...\n' 116 | ' "/" ...\n' 117 | ' "%" ...\n' 118 | ' "+" ...\n' 119 | ' "-" ...' 120 | ) 121 | self.assertEqual(result, expected) 122 | --------------------------------------------------------------------------------