├── .github └── workflows │ └── publish.yml ├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── ci └── travis │ └── deploy.sh ├── odbc-cli ├── odbcli ├── __init__.py ├── __main__.py ├── app.py ├── cli.py ├── completion │ ├── __init__.py │ ├── mssqlcompleter.py │ ├── mssqlliterals │ │ ├── __init__.py │ │ ├── main.py │ │ └── sqlliterals.json │ ├── parseutils │ │ ├── __init__.py │ │ ├── ctes.py │ │ ├── meta.py │ │ ├── tables.py │ │ └── utils.py │ ├── prioritization.py │ └── sqlcompletion.py ├── config.py ├── conn.py ├── dbmetadata.py ├── disconnect_dialog.py ├── filters.py ├── layout.py ├── loginprompt.py ├── odbclirc ├── odbcstyle.py ├── preview.py ├── sidebar.py └── utils.py ├── setup.py └── tests ├── test_dbmetadata.py ├── test_parseutils.py └── test_sqlcompletion.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: odbcli 2 | on: push 3 | 4 | jobs: 5 | build-n-publish: 6 | runs-on: ubuntu-22.04 # setup-python provides no Python 3.7 build for ubuntu-24.04 (ubuntu-latest) 7 | 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 3.7 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: "3.7" 14 | 15 | - name: Update dev version 16 | if: github.event_name != 'push' || startsWith(github.ref, 'refs/tags') != true 17 | run: | 18 | sed -i "/__version__ = / s/\"$/.20${{github.run_number}}\"/" odbcli/__init__.py 19 | 20 | - name: Build and test 21 | run: | 22 | python -m pip install --upgrade --upgrade-strategy eager . twine pytest 23 | pytest .
24 | python setup.py sdist -d dist/ 25 | 26 | - name: Publish to Test PyPI 27 | if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags') 28 | uses: pypa/gh-action-pypi-publish@release/v1 29 | with: 30 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 31 | repository-url: https://test.pypi.org/legacy/ 32 | 33 | - name: Publish to PyPI 34 | if: startsWith(github.ref, 'refs/tags') 35 | uses: pypa/gh-action-pypi-publish@release/v1 36 | with: 37 | password: ${{ secrets.PYPI_API_TOKEN }} 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | cache: pip 3 | language: python 4 | 5 | env: 6 | global: 7 | - secure: YGpGnxf9XMnwFiOyXKvl8XiDodhPIN7+PzOytJ0bodNvhriaua8Y54NT04WXLzpcg59t0Bp0Yd6gUERuyZ4cslBBpHtfl74iUOJPOEk/V7H0562C7TAxXxI9+XltYBhnv55Bxq/4B1Rj0YQxJ0Dl3CMfJGovAPs6zLUA7p2fz+kkhIDcc4DC2l7DE3rYhUsRCYakjaCpGSfPT1LSOBheGPWN3taPnANr0M9pTDI+Ej+kpfip7Sxna+phKZjvy9gsdhr4Wbca52826wj29E3B2Z4qw1wBVTqCzT6oPTvLt2++etbAfDkGohXy0zBOHDy2nN9oILIci5GMYd6qivQTu2gQBM7n0woMjk3NNe6buu23t7DpI9EuVpEl/3ocPdqDudaR9kW5UyWYJd0lvEJeTUolKjDGnUiDEh/tDAvgXtWiqi8DTZJTnwDwFpGB8eKNPqJhbUL9lj7HtfyElwTzJyHViwJRvAtvbMCtLroEVZau41+cTS2aIVkJLUP7ECpIKXWNGzmhfTBWsuF2aKnEH3onG84/QbuS8ADFD5XAKUP9fyq7qK7sf2loDrhSrGG2f8laIrL53ZiUgQItmcHBqW64eXILJwHN4loucc+X/gS27mNPZQZGxdF1UmWTUUnzVCTmrrL8hQqVi4Kue7j2TLtX6GCaocZ+VTS2qVWO3RI= 8 | - secure: 
DAQBtOrsIbNav3Ac8FLKjOQNfnlPMX+cXRk4SfjvfxJciA8W16Aa2ZyYhC0AZiEBQbMwslYS2WyYNU3+NsQfoCJ+KoWYbnfT/qfvn+K+f0uKZTazZaOHSJA4tOs6kA88L1X5yHOwMalQqyRUEkQ1lHxcGxmjMFDi+FdsCLDVQIuLaAn+Apk3TNGXKzV7IaYNkHO5kQmTrl4/v8epJcD35lo5Yw8SkUGgo0rPosecXvZU3ZoR9CpFWfBid+NECYrv8Cp7syTT9srp0y95sE38wCwr9nttNmcCKe3fLGygh0RBLyrkNl6Ajpk7qJVaa2/McIO3+4qekXlCt7dX4UpEWUH3wDjHGb//JdsaLgIMTLWP393hm4g1SJQgnPCjVi/4hxfcGfF08dnQw4JYAHC4z1qjae5A6zoI30JHOqwH4Yu3Rj9GhFhsWNA7I2kNXg/el+df23jA7mxWfXwuY5YLck156NXij8vhLYGxltgDEtmp6U1bnppQpaJKxn/9dcXZCsbJb+x8FZ0OF3lkq82vnnXAZuQ41gTosuSg1w9DQKF6OImAG0KVXDwd5qTwp8thXdrVSRXKcUltvIT9+W7kyMEKHnusLJruOvtAuCNLArSE5keJdqtb+POLfOp9ZrUObs0j/77zfZH85VmjWspWkn9rZ1v10gYHZv6Lu8NHbCM= 9 | - secure: bJrbXXuC9O/2YazHi9D95rWiPPDyDlVUE1i9WSos3lMAbqapn2vKL2gRosrM1NcIVIhnAzwi7O8ngQRiKID7nWF1xYGhapZcIya3q6z40m0HNQus+XngzSCTNaWVCMqEASDGluHOLlc/Vnbip/3tcgUSlGHk6P1p/zzoUbMKOa0igcnCG0kxR6BkcrPL1qIWdCMTDkGlmT7IFHF4ilFkhGJWssjR4mFFZX2awJJtQ2BmudsUKeytNjUI3/yQyV3eoQo2EXQDcU34ObkZFZkymdPBqfy3sSw3To2JLc4aiq/iZzIfItUaUcWS7T9HbhUD1aDG4pWWhqCBdPDbwcc0XklgX5eL9ZXI9s++gKMO7V5nu4GSE39ShYruyq8mPiRVGO4oCUFPP1N8ZMFaXLEMg5Wa4s+VxuHSMB70l5VeL7UvorJ+Ij4djBtn2RF6IYymvpDH3q5n4W2BciVg7LRZ1GX8TtSZNae4O09qYQhMGQNIKCNykbvx3DJamcI1q54kJTCjub9RVcvzUxap0PgTxjHrGB7M90wg9bHmvTaJvl7fkFC8yrtOQ/vjwttiq5X7SCrEHaBSV9TEbLOB3jB2bIfLQ92T2USjbExetOartz/P+X1dreECHXhk9g55qAvob4k3uCE3K+WAEdbVw39mqhfUPz0iuKqTsrJJ5fNyX2Y= 10 | - secure: 
d6W0/c04ea+0Lb5FXzzVcV4yVSqEMwrB1DOEpRapTP50a6JumziLKhIyDyzsJMgqVmkGQsAgfDNtGOZ1k/UiA3rjSXlQ5YgIwQpr0xgntKDq1/Ha02xxBpXfaS+ts3cnygldOGKF7jtPOJ+sb3hSdVmX1+q25OoyW8VtUT3GVtK0vERKxePdNzwOmrUXihYsvMMDrebmvcQvXHyk0UhtKnPlYVfyzcqoRBhlKUkUfU1LhLDQvewWcwwg7fC+7eK3Njx7EyIZ2sxU2IzDsE933YksVmJKowZOZtJsfOh6e1dspmP5Id9bGl4wTH68IzoFXzEi+/RAo68wXRjL39G6/WaRqph7YQHtLfqak1zLM7Gh0jxYSseWyYAW29RbLV71H4isXG8X6lOQE8hq9VxLaXoWx9HHaCJWuQUZVGlcnlvOcOxlr+FKmG29Z3zm9RtsZuGuZXm3+Ndf32ZBebGr5RQg5R4kkb2R/BIcm/+FBYRdCO6zjx894VXDUdUDuuwSPj0V9EvhUEEcKxwjxezIvcyE6/GC+FirhESSU/b5ufWEIW2O3TAMJxthUN5QX9JUV4+4rNhW70Y968u/BbZBCIC6fYl4UXFNl3f+I09hs6EhMVWK88ytszn+L0dACnIpedKVeJQ46GNkiX5pwve/oLtaww1cHGWP+WJNNGQmACc= 11 | 12 | matrix: 13 | include: 14 | - python: 3.7 15 | dist: xenial 16 | sudo: required 17 | 18 | script: 19 | - cd $TRAVIS_BUILD_DIR 20 | - | 21 | if [[ "$TRAVIS_TAG" == "" ]]; then 22 | sed -i "/__version__ = / s/\"$/.20$TRAVIS_BUILD_NUMBER\"/" odbcli/__init__.py 23 | fi 24 | - pip install --upgrade --upgrade-strategy eager . twine pytest 25 | - pytest . 26 | - python setup.py sdist 27 | 28 | deploy: 29 | - provider: script 30 | skip_cleanup: true 31 | script: $TRAVIS_BUILD_DIR/ci/travis/deploy.sh 32 | on: 33 | branch: master 34 | - provider: script 35 | skip_cleanup: true 36 | script: $TRAVIS_BUILD_DIR/ci/travis/deploy.sh 37 | on: 38 | tags: true 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, 4 | are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 
8 | 9 | * Redistributions in binary form must reproduce the above copyright notice, this 10 | list of conditions and the following disclaimer in the documentation and/or 11 | other materials provided with the distribution. 12 | 13 | * Neither the name of the {organization} nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include odbcli/completion/mssqlliterals/sqlliterals.json 2 | include odbcli/odbclirc 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # odbc-cli 2 | 3 | *Please note: this package should be considered "alpha" - while you are more than welcome to use it, you should expect that getting it to work for you will require quite a bit of self-help on your part. 
At the same time, it may be a great opportunity for those that want to contribute.* 4 | 5 |

6 |     7 |

8 |

9 | 10 |

11 | 12 | [**odbc-cli**](https://github.com/detule/odbc-cli) is an interactive command line query tool intended to work with Database Management Systems (DBMS) supported by ODBC drivers. 13 | 14 | As is the case with the [remaining clients](https://github.com/dbcli/) derived from the [python prompt toolkit library](https://github.com/prompt-toolkit/python-prompt-toolkit), **odbc-cli** also supports a rich interactive command line experience, with features such as auto-completion, syntax-highlighting, multi-line queries, and query-history. 15 | 16 | Beyond these, some distinguishing features of **odbc-cli** are: 17 | 18 | - **Multi DBMS support**: In addition to supporting connections to multiple DBMS, **odbc-cli** lets you connect to, and query, multiple databases in the same session. 19 | - **An integrated object browser**: Navigate between connections and objects within a database. 20 | - **Small footprint and excellent performance**: One of the main motivations is to reduce both the on-disk and the in-memory footprint of the [existing Microsoft SQL Server client](https://github.com/dbcli/mssql-cli/), while at the same time improving query execution and reducing the time spent retrieving results. 21 | - **Out-of-database auto-completion**: Mostly relevant to SQL Server users, but auto-completion is "aware" of schema and table structure outside of the currently connected catalog / database. 22 | 23 | ## Installing and OS support 24 | 25 | The assumption is that the starting point is a box with a working ODBC setup. This means a driver manager (unixODBC, for example), together with ODBC drivers that are appropriate to the DBM Systems you intend to connect to.
26 | 27 | To install the latest version of the package marked as *stable*, simply: 28 | 29 | ```sh 30 | python -m pip install odbcli 31 | ``` 32 | 33 | *Development* versions, tracking the tip of the master branch, are hosted on Test PyPI, and can be installed, for example, with: 34 | 35 | ```sh 36 | python -m pip install --index-url https://test.pypi.org/simple/ odbcli 37 | ``` 38 | 39 | Notes: 40 | * In theory, this package should work under Windows, macOS, as well as Linux. I can only test Linux; help testing and developing on the other platforms (as well as Linux) is very much welcome. 41 | * The main supporting package, [**cyanodbc**](https://github.com/cyanodbc/cyanodbc), comes as a pre-compiled wheel. It requires a modern C++ library supporting the C++14 standard. The cyanodbc Linux wheel is built on Ubuntu 16 - not exactly bleeding edge. Anything newer should be fine. 42 | * As of https://github.com/detule/odbc-cli/commit/bea22885d0483de0c1899ebc26ff853568b0e417, **odbc-cli** requires `cyanodbc` version [0.0.2.136](https://test.pypi.org/project/Cyanodbc/0.0.2.136/#files) or newer. 43 | 44 | ## Usage 45 | 46 | See the [Usage section here](https://detule.github.io/odbc-cli/index.html#Usage). 47 | 48 | ## Supported DBMS 49 | 50 | I have had a chance to test connectivity and basic functionality with the following DBM Systems: 51 | 52 | * **Microsoft SQL Server** 53 | Support and usability here should be furthest along. While I encounter (and fix) an occasional issue, I use this client in this capacity daily. 54 | 55 | Driver notes: 56 | * OEM Driver: No known issues (I test with driver version 17.5). 57 | * FreeTDS: Please use version 1.2 or newer for optimal performance (older versions do not support the SQLColumns API endpoint applied to tables out-of-currently-connected-catalog). 58 | 59 | * **MySQL** 60 | I have had a chance to test connectivity and basic functionality, but contributor help is very much appreciated.
61 | 62 | * **SQLite** 63 | I have had a chance to test connectivity and basic functionality, but contributor help is very much appreciated. 64 | 65 | * **PostgreSQL** 66 | I have had a chance to test connectivity and basic functionality, but contributor help is very much appreciated. 67 | 68 | Driver notes: 69 | * Please consider using [psqlODBC 12.01](https://odbc.postgresql.org/docs/release.html) or newer for optimal performance (older versions, when used with PostgreSQL 12.0, seem to have a documented bug when calling into SQLColumns). 70 | 71 | * **Snowflake** 72 | I have had a chance to test connectivity and basic functionality, but contributor help is very much appreciated. 73 | 74 | Driver notes: 75 | * As of version 2.20 of their ODBC driver, consider specifying the `Database` field in the DSN configuration section in your INI files. If no `Database` is specified when connecting, the driver will report the empty string, despite being attached to a particular catalog. Specifying the database after connecting, using `USE`, works as expected. 76 | 77 | * **Other** DBM Systems with ODBC drivers not mentioned above should work with minimal, or hopefully no additional, configuration / effort. 78 | 79 | ## Reporting issues 80 | 81 | The best feature, multi DBMS support, is also a curse from a support perspective, as there are too-many-to-count combinations of: 82 | 83 | * Client platform (ex: Debian 10) 84 | * Database system (ex: SQL Server) 85 | * Database version (ex: 19) 86 | * ODBC driver manager (ex: unixODBC) 87 | * ODBC driver manager version (ex: 2.3.x) 88 | * ODBC driver (ex: FreeTDS) 89 | * ODBC driver version (ex: 1.2.3) 90 | 91 | that could be specific to your setup, contributing to the problem and making it difficult to replicate.
Please consider including all of this information when reporting an issue, but above all be prepared for the possibility that I may not be able to replicate and fix your issue (in which case, hopefully, you can contribute / code-up a solution). Since the use case for this client is so broad, the only way I see this project having decent support is if we build up a critical mass of users/developers. 92 | 93 | ## Troubleshooting 94 | 95 | ### Listing connections and connecting to databases 96 | 97 | The best way to resolve connectivity issues is to work directly in a Python console. In particular, try working directly with the `cyanodbc` package in an interactive session. 98 | 99 | * When starting the client, **odbc-cli** queries the driver manager for a list of available connections by executing: 100 | 101 | ``` 102 | import cyanodbc 103 | cyanodbc.datasources() 104 | ``` 105 | 106 | Make sure this command returns coherent output that agrees with your expectations before attempting anything else. If it does not, consult the documentation for your driver manager and make sure all the appropriate INI files are populated accordingly. 107 | 108 | * If, for example, you are attempting to connect to a DSN called `postgresql_db` (recall that it should be defined and configured in the INI configuration file appropriate to your driver manager), in the background **odbc-cli** attempts to establish a connection with a connection string similar to: 109 | 110 | ``` 111 | import cyanodbc 112 | conn = cyanodbc.connect("DSN=postgresql_db;UID=postgres;PWD=password") 113 | ``` 114 | 115 | If you experience issues connecting to a database, make sure you can establish a connection using the method above before moving on to troubleshoot other parts of the client. 116 | 117 | ## Acknowledgements 118 | 119 | This project would not be possible without the most excellent [python prompt toolkit library](https://github.com/prompt-toolkit/python-prompt-toolkit).
In addition, idea and code sharing between the [clients that leverage this library](https://github.com/dbcli/) is rampant, and this project is no exception - a big thanks to all the `dbcli` contributors. 120 | -------------------------------------------------------------------------------- /ci/travis/deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ue 2 | 3 | if ! [[ "$TRAVIS_TAG" == "" ]] 4 | then 5 | echo "Deploying wheel to test pypi" 6 | twine upload --repository-url https://test.pypi.org/legacy/ -u $TEST_PYPI_USERNAME -p $TEST_PYPI_PASSWORD $TRAVIS_BUILD_DIR/dist/odbcli*.tar.gz 7 | echo "Deploying wheel to pypi" 8 | twine upload -u $PYPI_USERNAME -p $PYPI_PASSWORD $TRAVIS_BUILD_DIR/dist/odbcli*.tar.gz 9 | else 10 | echo "Deploying wheel to test pypi" 11 | twine upload --repository-url https://test.pypi.org/legacy/ -u $TEST_PYPI_USERNAME -p $TEST_PYPI_PASSWORD $TRAVIS_BUILD_DIR/dist/odbcli*.tar.gz 12 | fi 13 | -------------------------------------------------------------------------------- /odbc-cli: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SOURCE="${BASH_SOURCE[0]}" 4 | while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink 5 | DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" 6 | SOURCE="$(readlink "$SOURCE")" 7 | [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located 8 | done 9 | DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" 10 | 11 | # Set the python io encoding to UTF-8 by default if not set. 
12 | if [ -z ${PYTHONIOENCODING+x} ]; then export PYTHONIOENCODING=UTF-8; fi 13 | 14 | export PYTHONPATH="${DIR}:${PYTHONPATH}" 15 | 16 | python -m odbcli.__main__ "$@" 17 | -------------------------------------------------------------------------------- /odbcli/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.6" 2 | -------------------------------------------------------------------------------- /odbcli/__main__.py: -------------------------------------------------------------------------------- 1 | from .cli import main 2 | 3 | if __name__ == "__main__": 4 | main() 5 | -------------------------------------------------------------------------------- /odbcli/app.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar 5 | from prompt_toolkit.enums import EditingMode 6 | from prompt_toolkit.key_binding import KeyBindings, merge_key_bindings 7 | from prompt_toolkit.keys import Keys 8 | from prompt_toolkit.key_binding.bindings.auto_suggest import load_auto_suggest_bindings 9 | from prompt_toolkit.application import Application 10 | from prompt_toolkit.key_binding.bindings.focus import focus_next 11 | from prompt_toolkit.filters import Condition, has_focus 12 | from logging.handlers import RotatingFileHandler 13 | from cyanodbc import datasources 14 | from .sidebar import myDBConn, myDBObject 15 | from .conn import sqlConnection 16 | from .completion.mssqlcompleter import MssqlCompleter 17 | from .config import get_config, config_location, ensure_dir_exists 18 | from .odbcstyle import style_factory 19 | from .layout import sqlAppLayout 20 | 21 | class ExitEX(Exception): 22 | pass 23 | 24 | class sqlApp: 25 | def __init__( 26 | self, 27 | odbclirc_file = None 28 | ) -> None: 29 | c = self.config = get_config(odbclirc_file) 30 | self.initialize_logging() 
31 | self.set_default_pager(c) 32 | self.mouse_support: bool = c["main"].as_bool("mouse_support") 33 | self.fetch_chunk_multiplier = c["main"].as_int("fetch_chunk_multiplier") 34 | self.preview_limit_rows = c["main"].as_int("preview_limit_rows") 35 | self.preview_chunk_size = c["main"].as_int("preview_fetch_chunk_size") 36 | self.pager_reserve_lines = c["main"].as_int("pager_reserve_lines") 37 | self.table_format = c["main"]["table_format"] 38 | self.timing_enabled = c["main"].as_bool("timing") 39 | self.syntax_style = c["main"]["syntax_style"] 40 | self.cli_style = c["colors"] 41 | self.multiline: bool = c["main"].as_bool("multi_line") 42 | self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines") 43 | 44 | self.show_exit_confirmation: bool = False 45 | self.exit_message: str = "Do you really want to exit?" 46 | 47 | self.show_expanding_object: bool = False 48 | 49 | self.show_sidebar: bool = True 50 | self.show_login_prompt: bool = False 51 | self.show_preview: bool = False 52 | self.show_disconnect_dialog: bool = False 53 | self._active_conn = None 54 | self.obj_list = [] 55 | # Flag to signal to some of the prompt toolkit structures that we need 56 | # to traverse the obj_list anew to list all the objects in the sidebar. 57 | # Added for efficiency (no need to traverse unless necessary). Updated 58 | # from the main thread always, so no need for locking. 59 | self.obj_list_changed: bool = True 60 | # This field is a list with two elements. The first is the index of 61 | # the currently selected object (0-indexed). The second is the 62 | # index of the currently selected object in the 63 | # list of objects where each object is counted with length of characters 64 | # in name + 1 multiplicity. So for example a list of objects 65 | # A 66 | # AB 67 | # ABC 68 | # ABCD 69 | # where "ABC" is selected would present the index as 5 ([1+1] + [2+1]). 
70 | # This is used to track the cursor position in the sidebar document 71 | # It is recorded here, rather than elsewhere because we can track it 72 | # here far more efficiently (select_next, and select_previous). 73 | # It is important that all methods of this class that manipulate the 74 | # currently selected object also update this index. 75 | self._selected_obj_idx = [0, 0] 76 | dsns = list(datasources().keys()) 77 | if len(dsns) < 1: 78 | sys.exit("No datasources found ... exiting.") 79 | for dsn in dsns: 80 | self.obj_list.append(myDBConn( 81 | my_app = self, 82 | conn = sqlConnection(dsn = dsn), 83 | name = dsn, 84 | otype = "Connection")) 85 | for i in range(len(self.obj_list) - 1): 86 | self.obj_list[i].next_object = self.obj_list[i + 1] 87 | # Loop over side-bar when moving past the element on the bottom 88 | self.obj_list[len(self.obj_list) - 1].next_object = self.obj_list[0] 89 | self._selected_object = self.obj_list[0] 90 | self.completer = MssqlCompleter(smart_completion = True, get_conn = lambda: self.active_conn) 91 | 92 | self.application = self._create_application() 93 | 94 | @property 95 | def active_conn(self) -> sqlConnection: 96 | return self._active_conn 97 | 98 | @active_conn.setter 99 | def active_conn(self, conn: sqlConnection) -> None: 100 | self._active_conn = conn 101 | 102 | @property 103 | def selected_object(self) -> myDBObject: 104 | return self._selected_object 105 | 106 | @selected_object.setter 107 | def selected_object(self, obj) -> None: 108 | """ Avoid using / computationally expensive. 109 | Instead try using select_next / select_previous if possible. 110 | Will update _selected_obj_idx appropriately. 
111 | """ 112 | cursor = 0 113 | idx = 0 114 | o = self.obj_list[0] 115 | self._selected_object = obj 116 | while o is not self._selected_object: 117 | if not o.next_object: 118 | raise IndexError 119 | cursor += len(o.name) + 1 120 | idx += 1 121 | o = o.next_object 122 | self._selected_obj_idx = [idx, cursor] 123 | 124 | def select(self, idx) -> None: 125 | """ Select the [i]-th object in the list. Will also update 126 | _selected_obj_idx appropriately. 127 | """ 128 | counter = 0 129 | cursor = 0 130 | o = self.obj_list[0] 131 | while counter < idx: 132 | if not o.next_object: 133 | raise IndexError 134 | counter += 1 135 | cursor += len(o.name) + 1 136 | o = o.next_object 137 | self._selected_object = o 138 | self._selected_obj_idx = [idx, cursor] 139 | 140 | @property 141 | def selected_object_idx(self): 142 | return self._selected_obj_idx 143 | 144 | def select_next(self) -> None: 145 | self._selected_object = self.selected_object.next_object 146 | 147 | def select_previous(self) -> None: 148 | obj = self.selected_object.parent if self.selected_object.parent is not None else self.obj_list[0] 149 | while obj.next_object is not self.selected_object: 150 | obj = obj.next_object 151 | self._selected_object = obj 152 | 153 | @property 154 | def editing_mode(self) -> EditingMode: 155 | return self.application.editing_mode 156 | 157 | @editing_mode.setter 158 | def editing_mode(self, value: EditingMode) -> None: 159 | app = self.application 160 | app.editing_mode = value 161 | 162 | @property 163 | def vi_mode(self) -> bool: 164 | return self.editing_mode == EditingMode.VI 165 | 166 | @vi_mode.setter 167 | def vi_mode(self, value: bool) -> None: 168 | if value: 169 | self.editing_mode = EditingMode.VI 170 | else: 171 | self.editing_mode = EditingMode.EMACS 172 | 173 | def set_default_pager(self, config): 174 | configured_pager = config["main"].get("pager") 175 | os_environ_pager = os.environ.get("PAGER") 176 | 177 | if configured_pager: 178 | self.logger.info( 179 | 
'Default pager found in config file: "%s"', configured_pager 180 | ) 181 | os.environ["PAGER"] = configured_pager 182 | elif os_environ_pager: 183 | self.logger.info( 184 | 'Default pager found in PAGER environment variable: "%s"', 185 | os_environ_pager, 186 | ) 187 | os.environ["PAGER"] = os_environ_pager 188 | else: 189 | self.logger.info( 190 | "No default pager found in environment. Using os default pager" 191 | ) 192 | 193 | # Set default set of less recommended options, if they are not already set. 194 | # They are ignored if pager is different than less. 195 | if not os.environ.get("LESS"): 196 | os.environ["LESS"] = "-SRXF" 197 | 198 | def initialize_logging(self): 199 | log_file = self.config['main']['log_file'] 200 | if log_file == 'default': 201 | log_file = config_location() + 'odbcli.log' 202 | ensure_dir_exists(log_file) 203 | log_level = self.config['main']['log_level'] 204 | 205 | # Disable logging if value is NONE by switching to a no-op handler. 206 | # Set log level to a high value so it doesn't even waste cycles getting 207 | # called. 
208 | if log_level.upper() == 'NONE': 209 | handler = logging.NullHandler() 210 | else: 211 | # creates a log buffer with max size of 20 MB and 5 backup files 212 | handler = RotatingFileHandler(os.path.expanduser(log_file), 213 | encoding='utf-8', maxBytes=1024*1024*20, backupCount=5) 214 | 215 | level_map = {'CRITICAL': logging.CRITICAL, 216 | 'ERROR': logging.ERROR, 217 | 'WARNING': logging.WARNING, 218 | 'INFO': logging.INFO, 219 | 'DEBUG': logging.DEBUG, 220 | 'NONE': logging.CRITICAL 221 | } 222 | 223 | log_level = level_map[log_level.upper()] 224 | 225 | formatter = logging.Formatter( 226 | '%(asctime)s (%(process)d/%(threadName)s) ' 227 | '%(name)s %(levelname)s - %(message)s') 228 | 229 | handler.setFormatter(formatter) 230 | 231 | root_logger = logging.getLogger('odbcli') 232 | root_logger.addHandler(handler) 233 | root_logger.setLevel(log_level) 234 | 235 | root_logger.info('Initializing odbcli logging.') 236 | root_logger.debug('Log file %r.', log_file) 237 | self.logger = logging.getLogger(__name__) 238 | 239 | def _create_application(self) -> Application: 240 | self.sql_layout = sqlAppLayout(my_app = self) 241 | kb = KeyBindings() 242 | 243 | confirmation_visible = Condition(lambda: self.show_exit_confirmation) 244 | @kb.add("c-q") 245 | def _(event): 246 | " Pressing Ctrl-Q or Ctrl-C will exit the user interface. " 247 | self.show_exit_confirmation = True 248 | 249 | @kb.add("y", filter=confirmation_visible) 250 | @kb.add("Y", filter=confirmation_visible) 251 | @kb.add("enter", filter=confirmation_visible) 252 | @kb.add("c-q", filter=confirmation_visible) 253 | def _(event): 254 | """ 255 | Really quit. 256 | """ 257 | event.app.exit(exception = ExitEX(), style="class:exiting") 258 | 259 | @kb.add(Keys.Any, filter=confirmation_visible) 260 | def _(event): 261 | """ 262 | Cancel exit. 263 | """ 264 | self.show_exit_confirmation = False 265 | 266 | # Global key bindings. 
267 | @kb.add("tab", filter = Condition(lambda: self.show_preview or self.show_login_prompt)) 268 | def _(event): 269 | event.app.layout.focus_next() 270 | @kb.add("f4") 271 | def _(event): 272 | " Toggle between Emacs and Vi mode. " 273 | self.vi_mode = not self.vi_mode 274 | # apparently ctrls does this 275 | @kb.add("c-t", filter = Condition(lambda: not self.show_preview)) 276 | def _(event): 277 | """ 278 | Show/hide sidebar. 279 | """ 280 | self.show_sidebar = not self.show_sidebar 281 | if self.show_sidebar: 282 | event.app.layout.focus("sidebarbuffer") 283 | else: 284 | event.app.layout.focus_previous() 285 | 286 | sidebar_visible = Condition(lambda: self.show_sidebar and not self.show_expanding_object and not self.show_login_prompt and not self.show_preview) \ 287 | & ~has_focus("sidebarsearchbuffer") 288 | @kb.add("up", filter=sidebar_visible) 289 | @kb.add("c-p", filter=sidebar_visible) 290 | @kb.add("k", filter=sidebar_visible) 291 | def _(event): 292 | " Go to previous option. " 293 | obj = self._selected_object 294 | self.select_previous() 295 | inc = len(self.selected_object.name) + 1 # newline character 296 | if obj is self.obj_list[0]: 297 | idx = 0 298 | cursor = 0 299 | while obj is not self._selected_object: 300 | if not obj.next_object: 301 | raise IndexError 302 | cursor += len(obj.name) + 1 303 | idx += 1 304 | obj = obj.next_object 305 | self._selected_obj_idx = [idx, cursor] 306 | else: 307 | self._selected_obj_idx[0] -= 1 308 | self._selected_obj_idx[1] -= inc 309 | 310 | @kb.add("down", filter=sidebar_visible) 311 | @kb.add("c-n", filter=sidebar_visible) 312 | @kb.add("j", filter=sidebar_visible) 313 | def _(event): 314 | " Go to next option. 
" 315 | inc = len(self.selected_object.name) + 1 # newline character 316 | self.select_next() 317 | if self.selected_object is self.obj_list[0]: 318 | self._selected_obj_idx = [0, 0] 319 | else: 320 | self._selected_obj_idx[0] += 1 321 | self._selected_obj_idx[1] += inc 322 | 323 | @kb.add("enter", filter = sidebar_visible) 324 | def _(event): 325 | " If connection, connect. If table preview" 326 | obj = self.selected_object 327 | if type(obj).__name__ == "myDBConn" and not obj.conn.connected(): 328 | self.show_login_prompt = True 329 | event.app.layout.focus(self.sql_layout.lprompt) 330 | if type(obj).__name__ == "myDBConn" and obj.conn.connected(): 331 | # OG: some thread locking may be needed here 332 | self._active_conn = obj.conn 333 | elif obj.otype in ["table", "view", "function"]: 334 | self.show_preview = True 335 | self.show_sidebar = False 336 | event.app.layout.focus(self.sql_layout.preview) 337 | 338 | @kb.add("right", filter=sidebar_visible) 339 | @kb.add("l", filter=sidebar_visible) 340 | @kb.add(" ", filter=sidebar_visible) 341 | def _(event): 342 | " Select next value for current option. " 343 | obj = self.selected_object 344 | obj.expand() 345 | if type(obj).__name__ == "myDBConn" and not obj.conn.connected(): 346 | self.show_login_prompt = True 347 | event.app.layout.focus(self.sql_layout.lprompt) 348 | 349 | @kb.add("left", filter=sidebar_visible) 350 | @kb.add("h", filter=sidebar_visible) 351 | def _(event): 352 | " Select next value for current option. 
" 353 | obj = self.selected_object 354 | if type(obj).__name__ == "myDBConn" and obj.conn.connected() and obj.children is None: 355 | self.show_disconnect_dialog = True 356 | event.app.layout.focus(self.sql_layout.disconnect_dialog) 357 | else: 358 | obj.collapse() 359 | 360 | auto_suggest_bindings = load_auto_suggest_bindings() 361 | 362 | return Application( 363 | layout = self.sql_layout.layout, 364 | key_bindings = merge_key_bindings([kb, auto_suggest_bindings]), 365 | enable_page_navigation_bindings = True, 366 | style = style_factory(self.syntax_style, self.cli_style), 367 | include_default_pygments_style = False, 368 | mouse_support = self.mouse_support, 369 | full_screen = False, 370 | editing_mode = EditingMode.VI if self.config["main"].as_bool("vi") else EditingMode.EMACS 371 | ) 372 | -------------------------------------------------------------------------------- /odbcli/cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | A simple example of a calculator program. 4 | This could be used as inspiration for a REPL. 
5 | """ 6 | import os 7 | from sys import stderr 8 | from time import time 9 | from cyanodbc import DatabaseError, datasources 10 | from click import echo_via_pager, secho 11 | from prompt_toolkit.patch_stdout import patch_stdout 12 | from prompt_toolkit.utils import get_cwidth 13 | from .app import sqlApp, ExitEX 14 | from .layout import sqlAppLayout 15 | from .conn import connStatus, executionStatus 16 | 17 | 18 | def main(): 19 | 20 | my_app = sqlApp() 21 | # with patch_stdout(): 22 | while True: 23 | try: 24 | app_res = my_app.application.run() 25 | except ExitEX: 26 | for i in range(len(my_app.obj_list)): 27 | my_app.obj_list[i].conn.close() 28 | return 29 | else: 30 | # If it's a preview query we need an indication 31 | # of where to run the query 32 | if(app_res[0] == "preview"): 33 | sql_conn = my_app.selected_object.conn 34 | else: 35 | sql_conn = my_app.active_conn 36 | if sql_conn is not None: 37 | #TODO also check that it is connected 38 | try: 39 | secho("Executing query...Ctrl-c to cancel", err = False) 40 | start = time() 41 | crsr = sql_conn.async_execute(app_res[1]) 42 | execution = time() - start 43 | secho("Query execution...done", err = False) 44 | if(app_res[0] == "preview"): 45 | sql_conn.async_fetchall(my_app.preview_chunk_size, 46 | my_app.application) 47 | continue 48 | if my_app.timing_enabled: 49 | print("Time: %0.03fs" % execution) 50 | 51 | if sql_conn.execution_status == executionStatus.FAIL: 52 | err = sql_conn.execution_err 53 | secho("Query error: %s\n" % err, err = True, fg = "red") 54 | else: 55 | if crsr.description: 56 | cols = [col.name for col in crsr.description] 57 | else: 58 | cols = [] 59 | if len(cols): 60 | ht = my_app.application.output.get_size()[0] 61 | sql_conn.async_fetchall(my_app.fetch_chunk_multiplier * 62 | (ht - 3 - my_app.pager_reserve_lines), my_app.application) 63 | formatted = sql_conn.formatted_fetch(ht - 3 - my_app.pager_reserve_lines, cols, my_app.table_format) 64 | echo_via_pager(formatted) 65 | else: 
66 | secho("No rows returned\n", err = False) 67 | except KeyboardInterrupt: 68 | secho("Cancelling query...", err = True, fg = "red") 69 | sql_conn.cancel() 70 | secho("Query cancelled.", err = True, fg = "red") 71 | sql_conn.close_cursor() 72 | -------------------------------------------------------------------------------- /odbcli/completion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbcli/odbc-cli/2b061f4d700067ee3ceaca24da86af5cd3e8b21f/odbcli/completion/__init__.py -------------------------------------------------------------------------------- /odbcli/completion/mssqlliterals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbcli/odbc-cli/2b061f4d700067ee3ceaca24da86af5cd3e8b21f/odbcli/completion/mssqlliterals/__init__.py -------------------------------------------------------------------------------- /odbcli/completion/mssqlliterals/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | root = os.path.dirname(__file__) 5 | literal_file = os.path.join(root, 'sqlliterals.json') 6 | 7 | with open(literal_file) as f: 8 | literals = json.load(f) 9 | 10 | 11 | def get_literals(literal_type, type_=tuple): 12 | # Where `literal_type` is one of 'keywords', 'functions', 'datatypes', 13 | # returns a tuple of literal values of that type. 
14 | 15 | return type_(literals[literal_type]) 16 | -------------------------------------------------------------------------------- /odbcli/completion/mssqlliterals/sqlliterals.json: -------------------------------------------------------------------------------- 1 | { 2 | "keywords": { 3 | "ADD": [], 4 | "ALL": [], 5 | "ALTER": [ 6 | "APPLICATION ROLE", 7 | "ASSEMBLY", 8 | "ASYMMETRIC KEY", 9 | "AUTHORIZATION", 10 | "AVAILABILITY GROUP", 11 | "BROKER PRIORITY", 12 | "CERTIFICATE", 13 | "COLUMN ENCRYPTION KEY", 14 | "CREDENTIAL", 15 | "CRYPTOGRAPHIC PROVIDER", 16 | "DATABASE", 17 | "DATABASE AUDIT SPECIFICATION", 18 | "DATABASE ENCRYPTION KEY", 19 | "DATABASE HADR", 20 | "DATABASE SCOPED CREDENTIAL", 21 | "DATABASE SCOPED CONFIGURATION", 22 | "DATABASE SET Options", 23 | "ENDPOINT", 24 | "EVENT SESSION", 25 | "EXTERNAL DATA SOURCE", 26 | "EXTERNAL LIBRARY", 27 | "EXTERNAL RESOURCE POOL", 28 | "FULLTEXT CATALOG", 29 | "FULLTEXT INDEX", 30 | "FULLTEXT STOPLIST", 31 | "FUNCTION", 32 | "INDEX", 33 | "LOGIN", 34 | "MASTER KEY", 35 | "MESSAGE TYPE", 36 | "PARTITION FUNCTION", 37 | "PARTITION SCHEME", 38 | "PROCEDURE", 39 | "QUEUE", 40 | "REMOTE SERVICE BINDING", 41 | "RESOURCE GOVERNOR", 42 | "RESOURCE POOL", 43 | "ROLE", 44 | "ROUTE", 45 | "SCHEMA", 46 | "SEARCH PROPERTY LIST", 47 | "SECURITY POLICY", 48 | "SEQUENCE", 49 | "SERVER AUDIT", 50 | "SERVER AUDIT SPECIFICATION", 51 | "SERVER CONFIGURATION", 52 | "SERVER ROLE", 53 | "SERVICE", 54 | "SERVICE MASTER KEY", 55 | "SYMMETRIC KEY", 56 | "TABLE", 57 | "TRIGGER", 58 | "USER", 59 | "VIEW", 60 | "WORKLOAD GROUP", 61 | "XML SCHEMA COLLECTION" 62 | ], 63 | "ANY": [], 64 | "AND": [], 65 | "AS": [], 66 | "ASC": [], 67 | "AUTHORIZATION": [], 68 | "BACKUP": [], 69 | "BEGIN": [], 70 | "BETWEEN": [], 71 | "BREAK": [], 72 | "BROWSE": [], 73 | "BULK": [], 74 | "BY": [], 75 | "CASCADE": [], 76 | "CASE": [], 77 | "CHECK": [], 78 | "CHECKPOINT": [], 79 | "CLOSE": [], 80 | "CLUSTERED": [], 81 | "COALESCE": [], 82 | "COLLATE": [], 83 
| "COLUMN": [], 84 | "COMMIT": [], 85 | "COMPUTE": [], 86 | "CONSTRAINT": [], 87 | "CONTAINS": [], 88 | "CONTAINSTABLE": [], 89 | "CONTINUE": [], 90 | "CONVERT": [], 91 | "CREATE": [ 92 | "APPLICATION ROLE", 93 | "ASSEMBLY", 94 | "ASYMMETRIC KEY", 95 | "AUTHORIZATION", 96 | "AVAILABILITY GROUP", 97 | "BROKER PRIORITY", 98 | "CERTIFICATE", 99 | "COLUMN ENCRYPTION KEY", 100 | "CREDENTIAL", 101 | "CRYPTOGRAPHIC PROVIDER", 102 | "DATABASE", 103 | "DATABASE AUDIT SPECIFICATION", 104 | "DATABASE ENCRYPTION KEY", 105 | "DATABASE HADR", 106 | "DATABASE SCOPED CREDENTIAL", 107 | "DATABASE SCOPED CONFIGURATION", 108 | "DATABASE SET Options", 109 | "ENDPOINT", 110 | "EVENT SESSION", 111 | "EXTERNAL DATA SOURCE", 112 | "EXTERNAL LIBRARY", 113 | "EXTERNAL RESOURCE POOL", 114 | "FULLTEXT CATALOG", 115 | "FULLTEXT INDEX", 116 | "FULLTEXT STOPLIST", 117 | "FUNCTION", 118 | "INDEX", 119 | "LOGIN", 120 | "MASTER KEY", 121 | "MESSAGE TYPE", 122 | "PARTITION FUNCTION", 123 | "PARTITION SCHEME", 124 | "PROCEDURE", 125 | "QUEUE", 126 | "REMOTE SERVICE BINDING", 127 | "RESOURCE GOVERNOR", 128 | "RESOURCE POOL", 129 | "ROLE", 130 | "ROUTE", 131 | "SCHEMA", 132 | "SEARCH PROPERTY LIST", 133 | "SECURITY POLICY", 134 | "SEQUENCE", 135 | "SERVER AUDIT", 136 | "SERVER AUDIT SPECIFICATION", 137 | "SERVER CONFIGURATION", 138 | "SERVER ROLE", 139 | "SERVICE", 140 | "SERVICE MASTER KEY", 141 | "SYMMETRIC KEY", 142 | "TABLE", 143 | "TRIGGER", 144 | "USER", 145 | "VIEW", 146 | "WORKLOAD GROUP", 147 | "XML SCHEMA COLLECTION" 148 | ], 149 | "CROSS": [], 150 | "CURRENT": [], 151 | "CURRENT_DATE": [], 152 | "CURRENT_TIME": [], 153 | "CURRENT_TIMESTAMP": [], 154 | "CURRENT_USER": [], 155 | "CURSOR": [], 156 | "DATABASE": [], 157 | "DBCC": [], 158 | "DEALLOCATE": [], 159 | "DECLARE": [], 160 | "DEFAULT": [], 161 | "DELETE": [], 162 | "DENY": [], 163 | "DESC": [], 164 | "DISK": [], 165 | "DISTINCT": [], 166 | "DISTRIBUTED": [], 167 | "DOUBLE": [], 168 | "DROP": [ 169 | "APPLICATION ROLE", 170 | "ASSEMBLY", 
171 | "ASYMMETRIC KEY", 172 | "AUTHORIZATION", 173 | "AVAILABILITY GROUP", 174 | "BROKER PRIORITY", 175 | "CERTIFICATE", 176 | "COLUMN ENCRYPTION KEY", 177 | "CREDENTIAL", 178 | "CRYPTOGRAPHIC PROVIDER", 179 | "DATABASE", 180 | "DATABASE AUDIT SPECIFICATION", 181 | "DATABASE ENCRYPTION KEY", 182 | "DATABASE HADR", 183 | "DATABASE SCOPED CREDENTIAL", 184 | "DATABASE SCOPED CONFIGURATION", 185 | "DATABASE SET Options", 186 | "ENDPOINT", 187 | "EVENT SESSION", 188 | "EXTERNAL DATA SOURCE", 189 | "EXTERNAL LIBRARY", 190 | "EXTERNAL RESOURCE POOL", 191 | "FULLTEXT CATALOG", 192 | "FULLTEXT INDEX", 193 | "FULLTEXT STOPLIST", 194 | "FUNCTION", 195 | "INDEX", 196 | "LOGIN", 197 | "MASTER KEY", 198 | "MESSAGE TYPE", 199 | "PARTITION FUNCTION", 200 | "PARTITION SCHEME", 201 | "PROCEDURE", 202 | "QUEUE", 203 | "REMOTE SERVICE BINDING", 204 | "RESOURCE GOVERNOR", 205 | "RESOURCE POOL", 206 | "ROLE", 207 | "ROUTE", 208 | "SCHEMA", 209 | "SEARCH PROPERTY LIST", 210 | "SECURITY POLICY", 211 | "SEQUENCE", 212 | "SERVER AUDIT", 213 | "SERVER AUDIT SPECIFICATION", 214 | "SERVER CONFIGURATION", 215 | "SERVER ROLE", 216 | "SERVICE", 217 | "SERVICE MASTER KEY", 218 | "SYMMETRIC KEY", 219 | "TABLE", 220 | "TRIGGER", 221 | "USER", 222 | "VIEW", 223 | "WORKLOAD GROUP", 224 | "XML SCHEMA COLLECTION" 225 | ], 226 | "DUMP": [], 227 | "ELSE": [], 228 | "END": [], 229 | "ERRLVL": [], 230 | "ESCAPE": [], 231 | "EXCEPT": [], 232 | "EXEC": [], 233 | "EXECUTE": [], 234 | "EXISTS": [], 235 | "EXIT": [], 236 | "EXTERNAL": [], 237 | "FETCH": [], 238 | "FILE": [], 239 | "FILLFACTOR": [], 240 | "FOR": [], 241 | "FOREIGN": [], 242 | "FREETEXT": [], 243 | "FREETEXTTABLE": [], 244 | "FROM": [], 245 | "FULL": [], 246 | "FUNCTION": [], 247 | "GOTO": [], 248 | "GRANT": [], 249 | "GROUP": [], 250 | "GROUP BY": [], 251 | "HAVING": [], 252 | "HOLDLOCK": [], 253 | "IDENTITY": [], 254 | "IDENTITY_INSERT": [], 255 | "IDENTITYCOL": [], 256 | "IF": [], 257 | "IN": [], 258 | "INDEX": [], 259 | "INNER": [], 260 | 
"INSERT": [], 261 | "INTERSECT": [], 262 | "INTO": [], 263 | "IS": [], 264 | "JOIN": [], 265 | "KEY": [], 266 | "KILL": [], 267 | "LEFT": [], 268 | "LIKE": [], 269 | "LIMIT": [], 270 | "LINENO": [], 271 | "LOAD": [], 272 | "MERGE": [], 273 | "NATIONAL": [], 274 | "NOCHECK": [], 275 | "NONCLUSTERED": [], 276 | "NOT": [], 277 | "NULL": [], 278 | "NULLIF": [], 279 | "OF": [], 280 | "OFF": [], 281 | "OFFSETS": [], 282 | "ON": [], 283 | "OPEN": [], 284 | "OPENDATASOURCE": [], 285 | "OPENQUERY": [], 286 | "OPENROWSET": [], 287 | "OPENXML": [], 288 | "OPTION": [], 289 | "OR": [], 290 | "ORDER": [], 291 | "OUTER": [], 292 | "OVER": [], 293 | "PERCENT": [], 294 | "PIVOT": [], 295 | "PLAN": [], 296 | "PRECISION": [], 297 | "PRIMARY": [], 298 | "PRINT": [], 299 | "PROC": [], 300 | "PROCEDURE": [], 301 | "PUBLIC": [], 302 | "RAISERROR": [], 303 | "READ": [], 304 | "READTEXT": [], 305 | "RECONFIGURE": [], 306 | "REFERENCES": [], 307 | "REPLICATION": [], 308 | "RESTORE": [], 309 | "RESTRICT": [], 310 | "RETURN": [], 311 | "REVERT": [], 312 | "REVOKE": [], 313 | "RIGHT": [], 314 | "ROLLBACK": [], 315 | "ROWCOUNT": [], 316 | "ROWGUIDCOL": [], 317 | "RULE": [], 318 | "SAVE": [], 319 | "SCHEMA": [], 320 | "SECURITYAUDIT": [], 321 | "SELECT": [], 322 | "SEMANTICKEYPHRASETABLE": [], 323 | "SEMANTICSIMILARITYDETAILSTABLE": [], 324 | "SEMANTICSIMILARITYTABLE": [], 325 | "SESSION_USER": [], 326 | "SET": [], 327 | "SETUSER": [], 328 | "SHUTDOWN": [], 329 | "SOME": [], 330 | "STATISTICS": [], 331 | "SYSTEM_USER": [], 332 | "TABLE": [], 333 | "TABLESAMPLE": [], 334 | "TEXTSIZE": [], 335 | "THEN": [], 336 | "TO": [], 337 | "TOP": [], 338 | "TRAN": [], 339 | "TRANSFER": [], 340 | "TRANSACTION": [], 341 | "TRIGGER": [], 342 | "TRUNCATE": [], 343 | "TRY_CONVERT": [], 344 | "TSEQUAL": [], 345 | "UNION": [], 346 | "UNIQUE": [], 347 | "UNPIVOT": [], 348 | "UPDATE": [], 349 | "UPDATETEXT": [], 350 | "USE": [], 351 | "USER": [], 352 | "VALUES": [], 353 | "VARYING": [], 354 | "VIEW": [], 355 | 
"WAITFOR": [], 356 | "WHEN": [], 357 | "WHERE": [], 358 | "WHILE": [], 359 | "WITH": [], 360 | "WITHIN GROUP": [], 361 | "WRITETEXT": [] 362 | }, 363 | "functions": [ 364 | "AVG", 365 | "CHECKSUM_AGG", 366 | "COUNT", 367 | "COUNT_BIG", 368 | "GROUPING", 369 | "GROUPING_ID", 370 | "MAX", 371 | "MIN", 372 | "STDEV", 373 | "STDEVP", 374 | "STRING_SPLIT", 375 | "SUM", 376 | "VAR", 377 | "VARP" 378 | ], 379 | "datatypes": [ 380 | "BIGINT", 381 | "BINARY", 382 | "BIT", 383 | "CHAR", 384 | "CURSOR", 385 | "DATE", 386 | "DATETIME", 387 | "DATETIME2", 388 | "DATETIMEOFFSET", 389 | "DECIMAL", 390 | "FLOAT", 391 | "HIERARCHYID", 392 | "IMAGE", 393 | "INT", 394 | "MONEY", 395 | "NCHAR", 396 | "NTEXT", 397 | "NUMERIC", 398 | "NVARCHAR", 399 | "REAL", 400 | "ROWVERSION", 401 | "SMALLDATETIME", 402 | "SMALLINT", 403 | "SMALLMONEY", 404 | "SQL_VARIANT", 405 | "TABLE", 406 | "TEXT", 407 | "TIME", 408 | "TINYINT", 409 | "UNIQUEIDENTIFIER", 410 | "VARBINARY", 411 | "VARCHAR", 412 | "XML" 413 | ], 414 | "reserved": [ 415 | ] 416 | } 417 | -------------------------------------------------------------------------------- /odbcli/completion/parseutils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbcli/odbc-cli/2b061f4d700067ee3ceaca24da86af5cd3e8b21f/odbcli/completion/parseutils/__init__.py -------------------------------------------------------------------------------- /odbcli/completion/parseutils/ctes.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | from collections import namedtuple 3 | from sqlparse import parse 4 | from sqlparse.tokens import Keyword, CTE, DML 5 | from sqlparse.sql import Identifier, IdentifierList, Parenthesis 6 | from .meta import TableMetadata, ColumnMetadata 7 | 8 | 9 | # TableExpression is a namedtuple representing a CTE, used internally 10 | # name: cte alias assigned in the query 11 | # columns: 
list of column names 12 | # start: index into the original string of the left parens starting the CTE 13 | # stop: index into the original string of the right parens ending the CTE 14 | TableExpression = namedtuple('TableExpression', 'name columns start stop') 15 | 16 | 17 | def isolate_query_ctes(full_text, text_before_cursor): 18 | """Simplify a query by converting CTEs into table metadata objects 19 | """ 20 | 21 | if not full_text: 22 | return full_text, text_before_cursor, tuple() 23 | 24 | ctes, _ = extract_ctes(full_text) 25 | if not ctes: 26 | return full_text, text_before_cursor, () 27 | 28 | current_position = len(text_before_cursor) 29 | meta = [] 30 | 31 | for cte in ctes: 32 | if cte.start < current_position < cte.stop: 33 | # Currently editing a cte - treat its body as the current full_text 34 | text_before_cursor = full_text[cte.start:current_position] 35 | full_text = full_text[cte.start:cte.stop] 36 | return full_text, text_before_cursor, meta 37 | 38 | # Append this cte to the list of available table metadata 39 | cols = (ColumnMetadata(name, None, ()) for name in cte.columns) 40 | meta.append(TableMetadata(cte.name, cols)) 41 | 42 | # Editing past the last cte (ie the main body of the query) 43 | full_text = full_text[ctes[-1].stop:] 44 | text_before_cursor = text_before_cursor[ctes[-1].stop:current_position] 45 | 46 | return full_text, text_before_cursor, tuple(meta) 47 | 48 | 49 | def extract_ctes(sql): 50 | """ Extract constant table expresseions from a query 51 | 52 | Returns tuple (ctes, remainder_sql) 53 | 54 | ctes is a list of TableExpression namedtuples 55 | remainder_sql is the text from the original query after the CTEs have 56 | been stripped. 
57 | """ 58 | 59 | p = parse(sql)[0] 60 | 61 | # Make sure the first meaningful token is "WITH" which is necessary to 62 | # define CTEs 63 | idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) 64 | if not (tok and tok.ttype == CTE): 65 | return [], sql 66 | 67 | # Get the next (meaningful) token, which should be the first CTE 68 | idx, tok = p.token_next(idx) 69 | if not tok: 70 | return ([], '') 71 | start_pos = token_start_pos(p.tokens, idx) 72 | ctes = [] 73 | 74 | if isinstance(tok, IdentifierList): 75 | # Multiple ctes 76 | for t in tok.get_identifiers(): 77 | cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t)) 78 | cte = get_cte_from_token(t, start_pos + cte_start_offset) 79 | if not cte: 80 | continue 81 | ctes.append(cte) 82 | elif isinstance(tok, Identifier): 83 | # A single CTE 84 | cte = get_cte_from_token(tok, start_pos) 85 | if cte: 86 | ctes.append(cte) 87 | 88 | idx = p.token_index(tok) + 1 89 | 90 | # Collapse everything after the ctes into a remainder query 91 | remainder = u''.join(str(tok) for tok in p.tokens[idx:]) 92 | 93 | return ctes, remainder 94 | 95 | 96 | def get_cte_from_token(tok, pos0): 97 | cte_name = tok.get_real_name() 98 | if not cte_name: 99 | return None 100 | 101 | # Find the start position of the opening parens enclosing the cte body 102 | idx, parens = tok.token_next_by(Parenthesis) 103 | if not parens: 104 | return None 105 | 106 | start_pos = pos0 + token_start_pos(tok.tokens, idx) 107 | cte_len = len(str(parens)) # includes parens 108 | stop_pos = start_pos + cte_len 109 | 110 | column_names = extract_column_names(parens) 111 | 112 | return TableExpression(cte_name, column_names, start_pos, stop_pos) 113 | 114 | 115 | def extract_column_names(parsed): 116 | # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE 117 | idx, tok = parsed.token_next_by(t=DML) 118 | tok_val = tok and tok.value.lower() 119 | 120 | if tok_val in ('insert', 'update', 'delete'): 121 | # Jump ahead to the 
RETURNING clause where the list of column names is 122 | idx, tok = parsed.token_next_by(idx, (Keyword, 'returning')) 123 | elif not tok_val == 'select': 124 | # Must be invalid CTE 125 | return () 126 | 127 | # The next token should be either a column name, or a list of column names 128 | idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True) 129 | return tuple(t.get_name() for t in _identifiers(tok)) 130 | 131 | 132 | def token_start_pos(tokens, idx): 133 | return sum(len(str(t)) for t in tokens[:idx]) 134 | 135 | 136 | def _identifiers(tok): 137 | if isinstance(tok, IdentifierList): 138 | for t in tok.get_identifiers(): 139 | # NB: IdentifierList.get_identifiers() can return non-identifiers! 140 | if isinstance(t, Identifier): 141 | yield t 142 | elif isinstance(tok, Identifier): 143 | yield tok 144 | 145 | -------------------------------------------------------------------------------- /odbcli/completion/parseutils/meta.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=too-many-instance-attributes 2 | # pylint: disable=too-many-arguments 3 | 4 | from __future__ import print_function, unicode_literals 5 | from collections import namedtuple 6 | 7 | _ColumnMetadata = namedtuple( 8 | 'ColumnMetadata', 9 | ['name', 'datatype', 'foreignkeys', 'default', 'has_default'] 10 | ) 11 | 12 | 13 | def ColumnMetadata( 14 | name, datatype, foreignkeys=None, default=None, has_default=False 15 | ): 16 | return _ColumnMetadata( 17 | name, datatype, foreignkeys or [], default, has_default 18 | ) 19 | 20 | 21 | ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable', 22 | 'parentcolumn', 'childschema', 'childtable', 'childcolumn']) 23 | TableMetadata = namedtuple('TableMetadata', 'name columns') 24 | 25 | 26 | def parse_defaults(defaults_string): 27 | """ 28 | Yields default values for a function, given the string provided. 
29 | """ 30 | if not defaults_string: 31 | return 32 | current = '' 33 | in_quote = None 34 | for char in defaults_string: 35 | if current == '' and char == ' ': 36 | # Skip space after comma separating default expressions 37 | continue 38 | if char in ('"', '\''): 39 | if in_quote and char == in_quote: 40 | # End quote 41 | in_quote = None 42 | elif not in_quote: 43 | # Begin quote 44 | in_quote = char 45 | elif char == ',' and not in_quote: 46 | # End of expression 47 | yield current 48 | current = '' 49 | continue 50 | current += char 51 | yield current 52 | 53 | 54 | class FunctionMetadata: 55 | 56 | def __init__( 57 | self, schema_name, func_name, arg_names, arg_types, arg_modes, 58 | return_type, is_aggregate, is_window, is_set_returning, arg_defaults 59 | ): 60 | """Class for describing a postgresql function""" 61 | 62 | self.schema_name = schema_name 63 | self.func_name = func_name 64 | 65 | self.arg_modes = tuple(arg_modes) if arg_modes else None 66 | self.arg_names = tuple(arg_names) if arg_names else None 67 | 68 | # Be flexible in not requiring arg_types -- use None as a placeholder 69 | # for each arg. (Used for compatibility with old versions of postgresql 70 | # where such info is hard to get. 
71 | if arg_types: 72 | self.arg_types = tuple(arg_types) 73 | elif arg_modes: 74 | self.arg_types = tuple([None] * len(arg_modes)) 75 | elif arg_names: 76 | self.arg_types = tuple([None] * len(arg_names)) 77 | else: 78 | self.arg_types = None 79 | 80 | self.arg_defaults = tuple(parse_defaults(arg_defaults)) 81 | 82 | self.return_type = return_type.strip() 83 | self.is_aggregate = is_aggregate 84 | self.is_window = is_window 85 | self.is_set_returning = is_set_returning 86 | 87 | def __eq__(self, other): 88 | return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ 89 | 90 | def __ne__(self, other): 91 | return not self.__eq__(other) 92 | 93 | def _signature(self): 94 | return ( 95 | self.schema_name, self.func_name, self.arg_names, self.arg_types, 96 | self.arg_modes, self.return_type, self.is_aggregate, 97 | self.is_window, self.is_set_returning, self.arg_defaults 98 | ) 99 | 100 | def __hash__(self): 101 | return hash(self._signature()) 102 | 103 | def __repr__(self): 104 | return ( 105 | ( 106 | '%s(schema_name=%r, func_name=%r, arg_names=%r, ' 107 | 'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, ' 108 | 'is_window=%r, is_set_returning=%r, arg_defaults=%r)' 109 | ) % (self.__class__.__name__, self.schema_name, self.func_name, self.arg_names, 110 | self.arg_types, self.arg_modes, self.return_type, self.is_aggregate, 111 | self.is_window, self.is_set_returning, self.arg_defaults) 112 | ) 113 | 114 | def has_variadic(self): 115 | return self.arg_modes and any( 116 | arg_mode == 'v' for arg_mode in self.arg_modes) 117 | 118 | def args(self): 119 | """Returns a list of input-parameter ColumnMetadata namedtuples.""" 120 | if not self.arg_names: 121 | return [] 122 | modes = self.arg_modes or ['i'] * len(self.arg_names) 123 | args = [ 124 | (name, typ) 125 | for name, typ, mode in zip(self.arg_names, self.arg_types, modes) 126 | if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC 127 | ] 128 | 129 | def arg(name, typ, num): 130 | 
num_args = len(args) 131 | num_defaults = len(self.arg_defaults) 132 | has_default = num + num_defaults >= num_args 133 | default = ( 134 | self.arg_defaults[num - num_args + num_defaults] if has_default 135 | else None 136 | ) 137 | return ColumnMetadata(name, typ, [], default, has_default) 138 | 139 | return [arg(name, typ, num) for num, (name, typ) in enumerate(args)] 140 | 141 | def fields(self): 142 | """Returns a list of output-field ColumnMetadata namedtuples""" 143 | 144 | if self.return_type.lower() == 'void': 145 | return [] 146 | if not self.arg_modes: 147 | # For functions without output parameters, the function name 148 | # is used as the name of the output column. 149 | # E.g. 'SELECT unnest FROM unnest(...);' 150 | return [ColumnMetadata(self.func_name, self.return_type, [])] 151 | 152 | return [ColumnMetadata(name, typ, []) 153 | for name, typ, mode in zip( 154 | self.arg_names, self.arg_types, self.arg_modes) 155 | if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE 156 | 157 | -------------------------------------------------------------------------------- /odbcli/completion/parseutils/tables.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import namedtuple 3 | import sqlparse 4 | from sqlparse.sql import IdentifierList, Identifier, Function 5 | from sqlparse.tokens import Keyword, DML, Punctuation, Whitespace 6 | 7 | TableReference = namedtuple('TableReference', ['catalog', 'schema', 'name', 'alias', 8 | 'is_function']) 9 | TableReference.ref = property(lambda self: self.alias or ( 10 | self.name if self.name.islower() or self.name[0] == '"' 11 | else '"' + self.name + '"')) 12 | 13 | 14 | # This code is borrowed from sqlparse example script. 
15 | # 16 | def is_subselect(parsed): 17 | if not parsed.is_group: 18 | return False 19 | for item in parsed.tokens: 20 | if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', 21 | 'UPDATE', 'CREATE', 'DELETE'): 22 | return True 23 | return False 24 | 25 | 26 | def _identifier_is_function(identifier): 27 | return any(isinstance(t, Function) for t in identifier.tokens) 28 | 29 | 30 | def extract_from_part(parsed, stop_at_punctuation=True): 31 | tbl_prefix_seen = False 32 | for item in parsed.tokens: 33 | if tbl_prefix_seen: 34 | if is_subselect(item): 35 | for x in extract_from_part(item, stop_at_punctuation): 36 | yield x 37 | elif stop_at_punctuation and item.ttype is Punctuation: 38 | # An incomplete nested select won't be recognized correctly as a 39 | # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes 40 | # the second FROM to trigger this elif condition resulting in a 41 | # StopIteration. So we need to ignore the keyword if the keyword 42 | # FROM. 43 | # Also 'SELECT * FROM abc JOIN def' will trigger this elif 44 | # condition. So we need to ignore the keyword JOIN and its variants 45 | # INNER JOIN, FULL OUTER JOIN, etc. 46 | return 47 | elif item.ttype is Keyword and ( 48 | not item.value.upper() == 'FROM') and \ 49 | (not item.value.upper().endswith('JOIN')): 50 | tbl_prefix_seen = False 51 | else: 52 | yield item 53 | elif item.ttype is Keyword or item.ttype is Keyword.DML: 54 | item_val = item.value.upper() 55 | if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or 56 | item_val.endswith('JOIN')): 57 | tbl_prefix_seen = True 58 | # 'SELECT a, FROM abc' will detect FROM as part of the column list. 59 | # So this check here is necessary. 
60 | elif isinstance(item, IdentifierList): 61 | for identifier in item.get_identifiers(): 62 | if (identifier.ttype is Keyword and 63 | identifier.value.upper() == 'FROM'): 64 | tbl_prefix_seen = True 65 | break 66 | 67 | 68 | def extract_table_identifiers(token_stream, allow_functions=True): 69 | """yields tuples of TableReference namedtuples""" 70 | 71 | # We need to do some massaging of the names because postgres is case- 72 | # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is) 73 | def parse_identifier(item): 74 | alias = item.get_alias() 75 | sp_idx = item.token_next_by(t = Whitespace)[0] or len(item.tokens) 76 | item_rev = Identifier(list(reversed(item.tokens[:sp_idx]))) 77 | name = item_rev._get_first_name(real_name = True) 78 | alias = alias or name 79 | dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.')) 80 | if dot_idx is not None: 81 | schema_name = item_rev._get_first_name(dot_idx, real_name = True) 82 | dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'), idx = dot_idx) 83 | if dot_idx is not None: 84 | catalog_name = item_rev._get_first_name(dot_idx, real_name = True) 85 | else: 86 | catalog_name = None 87 | else: 88 | schema_name = None 89 | catalog_name = None 90 | # TODO: this business below needs help 91 | # for one we need to apply this logic to catalog_name 92 | # then the logic around name_quoted = quote_count > 2 obviously 93 | # doesn't work. 
Finally, quotechar needs to be customized 94 | schema_quoted = schema_name and item.value[0] == '"' 95 | if schema_name and not schema_quoted: 96 | schema_name = schema_name.lower() 97 | quote_count = item.value.count('"') 98 | name_quoted = quote_count > 2 or (quote_count and not schema_quoted) 99 | alias_quoted = alias and item.value[-1] == '"' 100 | if alias_quoted or name_quoted and not alias and name.islower(): 101 | alias = '"' + (alias or name) + '"' 102 | if name and not name_quoted and not name.islower(): 103 | if not alias: 104 | alias = name 105 | name = name.lower() 106 | return catalog_name, schema_name, name, alias 107 | 108 | for item in token_stream: 109 | if isinstance(item, IdentifierList): 110 | for identifier in item.get_identifiers(): 111 | # Sometimes Keywords (such as FROM ) are classified as 112 | # identifiers which don't have the get_real_name() method. 113 | try: 114 | catalog_name, schema_name, real_name, alias = parse_identifier(identifier) 115 | is_function = allow_functions and _identifier_is_function(identifier) 116 | except AttributeError: 117 | continue 118 | if real_name: 119 | yield TableReference(catalog_name, schema_name, real_name, 120 | identifier.get_alias(), is_function) 121 | elif isinstance(item, Identifier): 122 | catalog_name, schema_name, real_name, alias = parse_identifier(item) 123 | is_function = allow_functions and _identifier_is_function(item) 124 | 125 | yield TableReference(catalog_name, schema_name, real_name, alias, is_function) 126 | elif isinstance(item, Function): 127 | catalog_name, schema_name, real_name, alias = parse_identifier(item) 128 | yield TableReference(None, None, real_name, alias, allow_functions) 129 | 130 | 131 | # extract_tables is inspired from examples in the sqlparse lib. 132 | def extract_tables(sql): 133 | """Extract the table names from an SQL statment. 
134 | 135 | Returns a list of TableReference namedtuples 136 | 137 | """ 138 | parsed = sqlparse.parse(sql) 139 | if not parsed: 140 | return () 141 | 142 | # INSERT statements must stop looking for tables at the sign of first 143 | # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) 144 | # abc is the table name, but if we don't stop at the first lparen, then 145 | # we'll identify abc, col1 and col2 as table names. 146 | insert_stmt = parsed[0].token_first().value.lower() == 'insert' 147 | stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) 148 | 149 | # Kludge: sqlparse mistakenly identifies insert statements as 150 | # function calls due to the parenthesized column list, e.g. interprets 151 | # "insert into foo (bar, baz)" as a function call to foo with arguments 152 | # (bar, baz). So don't allow any identifiers in insert statements 153 | # to have is_function=True 154 | identifiers = extract_table_identifiers(stream, 155 | allow_functions=not insert_stmt) 156 | # In the case 'sche.', we get an empty TableReference; remove that 157 | return tuple(i for i in identifiers if i.name) 158 | 159 | -------------------------------------------------------------------------------- /odbcli/completion/parseutils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import re 3 | import sqlparse 4 | from sqlparse.sql import Identifier 5 | from sqlparse.tokens import Token, Error 6 | 7 | cleanup_regex = { 8 | # This matches only alphanumerics and underscores. 9 | 'alphanum_underscore': re.compile(r'(\w+)$'), 10 | # This matches everything except spaces, parens, colon, and comma 11 | 'many_punctuations': re.compile(r'([^():,\s]+)$'), 12 | # This matches everything except spaces, parens, colon, comma, and period 13 | 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), 14 | # This matches everything except a space. 
15 | 'all_punctuations': re.compile(r'([^\s]+)$'), 16 | } 17 | 18 | 19 | def last_word(text, include='alphanum_underscore'): 20 | r""" 21 | Find the last word in a sentence. 22 | 23 | >>> last_word('abc') 24 | 'abc' 25 | >>> last_word(' abc') 26 | 'abc' 27 | >>> last_word('') 28 | '' 29 | >>> last_word(' ') 30 | '' 31 | >>> last_word('abc ') 32 | '' 33 | >>> last_word('abc def') 34 | 'def' 35 | >>> last_word('abc def ') 36 | '' 37 | >>> last_word('abc def;') 38 | '' 39 | >>> last_word('bac $def') 40 | 'def' 41 | >>> last_word('bac $def', include='most_punctuations') 42 | '$def' 43 | >>> last_word('bac \def', include='most_punctuations') 44 | '\\\\def' 45 | >>> last_word('bac \def;', include='most_punctuations') 46 | '\\\\def;' 47 | >>> last_word('bac::def', include='most_punctuations') 48 | 'def' 49 | >>> last_word('"foo*bar', include='most_punctuations') 50 | '"foo*bar' 51 | """ 52 | 53 | if not text: # Empty string 54 | return '' 55 | 56 | if text[-1].isspace(): 57 | return '' 58 | regex = cleanup_regex[include] 59 | matches = regex.search(text) 60 | if matches: 61 | return matches.group(0) 62 | return '' 63 | 64 | 65 | def find_prev_keyword(sql, n_skip=0): 66 | """ Find the last sql keyword in an SQL statement 67 | 68 | Returns the value of the last keyword, and the text of the query with 69 | everything after the last keyword stripped 70 | """ 71 | if not sql.strip(): 72 | return None, '' 73 | 74 | parsed = sqlparse.parse(sql)[0] 75 | flattened = list(parsed.flatten()) 76 | flattened = flattened[:len(flattened) - n_skip] 77 | 78 | logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') 79 | 80 | for t in reversed(flattened): 81 | if t.value == '(' or (t.is_keyword and 82 | (t.value.upper() not in logical_operators) 83 | ): 84 | # Find the location of token t in the original parsed statement 85 | # We can't use parsed.token_index(t) because t may be a child token 86 | # inside a TokenList, in which case token_index thows an error 87 | # Minimal example: 88 | # p = 
sqlparse.parse('select * from foo where bar') 89 | # t = list(p.flatten())[-3] # The "Where" token 90 | # p.token_index(t) # Throws ValueError: not in list 91 | idx = flattened.index(t) 92 | 93 | # Combine the string values of all tokens in the original list 94 | # up to and including the target keyword token t, to produce a 95 | # query string with everything after the keyword token removed 96 | text = ''.join(tok.value for tok in flattened[:idx + 1]) 97 | return t, text 98 | 99 | return None, '' 100 | 101 | 102 | # Postgresql dollar quote signs look like `$$` or `$tag$` 103 | dollar_quote_regex = re.compile(r'^\$[^$]*\$$') 104 | 105 | 106 | def is_open_quote(sql): 107 | """Returns true if the query contains an unclosed quote""" 108 | 109 | # parsed can contain one or more semi-colon separated commands 110 | parsed = sqlparse.parse(sql) 111 | return any(_parsed_is_open_quote(p) for p in parsed) 112 | 113 | 114 | def _parsed_is_open_quote(parsed): 115 | # Look for unmatched single quotes, or unmatched dollar sign quotes 116 | return any(tok.match(Token.Error, ("'", '"', "$")) for tok in parsed.flatten()) 117 | 118 | 119 | def parse_partial_identifier(word): 120 | """Attempt to parse a (partially typed) word as an identifier 121 | 122 | word may include a schema qualification, like `schema_name.partial_name` 123 | or `schema_name.` There may also be unclosed quotation marks, like 124 | `"schema`, or `schema."partial_name` 125 | 126 | :param word: string representing a (partially complete) identifier 127 | :return: sqlparse.sql.Identifier, or None 128 | """ 129 | 130 | p = sqlparse.parse(word)[0] 131 | n_tok = len(p.tokens) 132 | if n_tok == 1 and isinstance(p.tokens[0], Identifier): 133 | return p.tokens[0] 134 | if p.token_next_by(m=(Error, '"'))[1]: 135 | # An unmatched double quote, e.g. 
'"foo', 'foo."', or 'foo."bar' 136 | # Close the double quote, then reparse 137 | return parse_partial_identifier(word + '"') 138 | return None 139 | 140 | -------------------------------------------------------------------------------- /odbcli/completion/prioritization.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | import re 4 | from collections import defaultdict 5 | import sqlparse 6 | from sqlparse.tokens import Name 7 | from .mssqlliterals.main import get_literals 8 | 9 | 10 | white_space_regex = re.compile('\\s+', re.MULTILINE) 11 | 12 | 13 | def _compile_regex(keyword): 14 | # Surround the keyword with word boundaries and replace interior whitespace 15 | # with whitespace wildcards 16 | pattern = '\\b' + white_space_regex.sub(r'\\s+', keyword) + '\\b' 17 | return re.compile(pattern, re.MULTILINE | re.IGNORECASE) 18 | 19 | 20 | keywords = get_literals('keywords') 21 | keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords) 22 | 23 | 24 | class PrevalenceCounter: 25 | def __init__(self): 26 | self.keyword_counts = defaultdict(int) 27 | self.name_counts = defaultdict(int) 28 | 29 | def update(self, text): 30 | self.update_keywords(text) 31 | self.update_names(text) 32 | 33 | def update_names(self, text): 34 | for parsed in sqlparse.parse(text): 35 | for token in parsed.flatten(): 36 | if token.ttype in Name: 37 | self.name_counts[token.value] += 1 38 | 39 | def clear_names(self): 40 | self.name_counts = defaultdict(int) 41 | 42 | def update_keywords(self, text): 43 | # Count keywords. 
Can't rely for sqlparse for this, because it's 44 | # database agnostic 45 | for keyword, regex in keyword_regexs.items(): 46 | for _ in regex.finditer(text): 47 | self.keyword_counts[keyword] += 1 48 | 49 | def keyword_count(self, keyword): 50 | return self.keyword_counts[keyword] 51 | 52 | def name_count(self, name): 53 | return self.name_counts[name] 54 | -------------------------------------------------------------------------------- /odbcli/completion/sqlcompletion.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=too-many-instance-attributes 2 | 3 | from __future__ import print_function, unicode_literals 4 | from collections import namedtuple 5 | import re 6 | import sqlparse 7 | from sqlparse.sql import Comparison, Identifier, Where 8 | from sqlparse.tokens import Keyword, DML, Punctuation 9 | from .parseutils.utils import ( 10 | last_word, find_prev_keyword, parse_partial_identifier) 11 | from .parseutils.tables import extract_tables 12 | from .parseutils.ctes import isolate_query_ctes 13 | 14 | try: 15 | string_types = basestring # Python 2 16 | except NameError: 17 | string_types = str # Python 3 18 | 19 | 20 | Blank = namedtuple('Blank', []) 21 | Special = namedtuple('Special', []) 22 | NamedQuery = namedtuple('NamedQuery', []) 23 | Database = namedtuple('Database', []) 24 | Schema = namedtuple('Schema', 'parent quoted') 25 | Schema.__new__.__defaults__ = (None, False) 26 | # FromClauseItem is a table/view/function used in the FROM clause 27 | # `table_refs` contains the list of tables/... already in the statement, 28 | # used to ensure that the alias we suggest is unique 29 | FromClauseItem = namedtuple('FromClauseItem', 'grandparent parent table_refs local_tables') 30 | Table = namedtuple('Table', ['catalog', 'schema', 'table_refs', 'local_tables']) 31 | View = namedtuple('View', ['catalog', 'schema', 'table_refs']) 32 | # JoinConditions are suggested after ON, e.g. 
'foo.barid = bar.barid' 33 | JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent']) 34 | # Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid' 35 | Join = namedtuple('Join', ['table_refs', 'schema', 'catalog']) 36 | 37 | Function = namedtuple('Function', ['catalog', 'schema', 'table_refs', 'usage']) 38 | # For convenience, don't require the `usage` argument in Function constructor 39 | Function.__new__.__defaults__ = (None, None, tuple(), None) 40 | Table.__new__.__defaults__ = (None, None, tuple(), tuple()) 41 | View.__new__.__defaults__ = (None, None, tuple()) 42 | FromClauseItem.__new__.__defaults__ = (None, None, tuple(), tuple()) 43 | 44 | Column = namedtuple( 45 | 'Column', 46 | ['table_refs', 'require_last_table', 'local_tables', 'qualifiable', 'context'] 47 | ) 48 | Column.__new__.__defaults__ = (None, None, tuple(), False, None) 49 | 50 | Keyword = namedtuple('Keyword', ['last_token']) 51 | Keyword.__new__.__defaults__ = (None,) 52 | Datatype = namedtuple('Datatype', ['schema']) 53 | Alias = namedtuple('Alias', ['aliases']) 54 | 55 | Path = namedtuple('Path', []) 56 | 57 | 58 | class SqlStatement: 59 | def __init__(self, full_text, text_before_cursor): 60 | self.identifier = None 61 | self.word_before_cursor = word_before_cursor = last_word( 62 | text_before_cursor, include='many_punctuations') 63 | full_text = _strip_named_query(full_text) 64 | text_before_cursor = _strip_named_query(text_before_cursor) 65 | 66 | full_text, text_before_cursor, self.local_tables = \ 67 | isolate_query_ctes(full_text, text_before_cursor) 68 | 69 | self.text_before_cursor_including_last_word = text_before_cursor 70 | 71 | # If we've partially typed a word then word_before_cursor won't be an 72 | # empty string. In that case we want to remove the partially typed 73 | # string before sending it to the sqlparser. 
Otherwise the last token 74 | # will always be the partially typed string which renders the smart 75 | # completion useless because it will always return the list of 76 | # keywords as completion. 77 | if self.word_before_cursor: 78 | if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\': 79 | parsed = sqlparse.parse(text_before_cursor) 80 | else: 81 | text_before_cursor = text_before_cursor[:- 82 | len(word_before_cursor)] 83 | parsed = sqlparse.parse(text_before_cursor) 84 | self.identifier = parse_partial_identifier(word_before_cursor) 85 | else: 86 | parsed = sqlparse.parse(text_before_cursor) 87 | 88 | full_text, text_before_cursor, parsed = \ 89 | _split_multiple_statements(full_text, text_before_cursor, parsed) 90 | 91 | self.full_text = full_text 92 | self.text_before_cursor = text_before_cursor 93 | self.parsed = parsed 94 | 95 | self.last_token = parsed.token_prev(len(parsed.tokens))[1] \ 96 | if parsed and parsed.token_prev(len(parsed.tokens))[1] else '' 97 | 98 | def is_insert(self): 99 | return self.parsed.token_first().value.lower() == 'insert' 100 | 101 | def get_tables(self, scope='full'): 102 | """ Gets the tables available in the statement. 103 | param `scope:` possible values: 'full', 'insert', 'before' 104 | If 'insert', only the first table is returned. 105 | If 'before', only tables before the cursor are returned. 106 | If not 'insert' and the stmt is an insert, the first table is skipped. 
107 | """ 108 | tables = extract_tables( 109 | self.full_text if scope == 'full' else self.text_before_cursor) 110 | if scope == 'insert': 111 | tables = tables[:1] 112 | elif self.is_insert(): 113 | tables = tables[1:] 114 | return tables 115 | 116 | def get_previous_token(self, token): 117 | return self.parsed.token_prev(self.parsed.token_index(token))[1] 118 | 119 | def get_identifier_schema(self): 120 | schema = self.identifier.get_parent_name() \ 121 | if (self.identifier and self.identifier.get_parent_name()) else None 122 | # If schema name is unquoted, lower-case it 123 | if schema and self.identifier.value[0] != '"': 124 | schema = schema.lower() 125 | 126 | return schema 127 | 128 | def get_identifier_parents(self): 129 | if self.identifier is None: 130 | return None, None 131 | item_rev = Identifier(list(reversed(self.identifier.tokens))) 132 | name = item_rev._get_first_name(real_name = True) 133 | dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.')) 134 | if dot_idx is not None: 135 | schema_name = item_rev._get_first_name(dot_idx, real_name = True) 136 | dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'), idx = dot_idx) 137 | if dot_idx is not None: 138 | catalog_name = item_rev._get_first_name(dot_idx, real_name = True) 139 | else: 140 | catalog_name = None 141 | else: 142 | schema_name = None 143 | catalog_name = None 144 | return catalog_name, schema_name 145 | 146 | def reduce_to_prev_keyword(self, n_skip=0): 147 | prev_keyword, self.text_before_cursor = \ 148 | find_prev_keyword(self.text_before_cursor, n_skip=n_skip) 149 | return prev_keyword 150 | 151 | 152 | def suggest_type(full_text, text_before_cursor): 153 | """Takes the full_text that is typed so far and also the text before the 154 | cursor to suggest completion type and scope. 155 | 156 | Returns a tuple with a type of entity ('table', 'column' etc) and a scope. 157 | A scope for a column category will be a list of tables. 
158 | """ 159 | 160 | if full_text.startswith('\\i '): 161 | return (Path(),) 162 | 163 | # This is a temporary hack; the exception handling 164 | # here should be removed once sqlparse has been fixed 165 | try: 166 | stmt = SqlStatement(full_text, text_before_cursor) 167 | except (TypeError, AttributeError): 168 | return [] 169 | 170 | # # Check for special commands and handle those separately 171 | # if stmt.parsed: 172 | # # Be careful here because trivial whitespace is parsed as a 173 | # # statement, but the statement won't have a first token 174 | # tok1 = stmt.parsed.token_first() 175 | # if tok1 and tok1.value == '\\': 176 | # text = stmt.text_before_cursor + stmt.word_before_cursor 177 | # return suggest_special(text) 178 | 179 | return suggest_based_on_last_token(stmt.last_token, stmt) 180 | 181 | 182 | named_query_regex = re.compile(r'^\s*\\sn\s+[A-z0-9\-_]+\s+') 183 | 184 | 185 | def _strip_named_query(txt): 186 | """ 187 | This will strip "save named query" command in the beginning of the line: 188 | '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' 189 | ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' 190 | """ 191 | 192 | if named_query_regex.match(txt): 193 | txt = named_query_regex.sub('', txt) 194 | return txt 195 | 196 | 197 | function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M) 198 | 199 | 200 | def _find_function_body(text): 201 | split = function_body_pattern.search(text) 202 | return (split.start(2), split.end(2)) if split else (None, None) 203 | 204 | 205 | def _statement_from_function(full_text, text_before_cursor, statement): 206 | current_pos = len(text_before_cursor) 207 | body_start, body_end = _find_function_body(full_text) 208 | if body_start is None: 209 | return full_text, text_before_cursor, statement 210 | if not body_start <= current_pos < body_end: 211 | return full_text, text_before_cursor, statement 212 | full_text = full_text[body_start:body_end] 213 | text_before_cursor = text_before_cursor[body_start:] 
214 | parsed = sqlparse.parse(text_before_cursor) 215 | return _split_multiple_statements(full_text, text_before_cursor, parsed) 216 | 217 | 218 | def _split_multiple_statements(full_text, text_before_cursor, parsed): 219 | if len(parsed) > 1: 220 | # Multiple statements being edited -- isolate the current one by 221 | # cumulatively summing statement lengths to find the one that bounds 222 | # the current position 223 | current_pos = len(text_before_cursor) 224 | stmt_start, stmt_end = 0, 0 225 | 226 | for statement in parsed: 227 | stmt_len = len(str(statement)) 228 | stmt_start, stmt_end = stmt_end, stmt_end + stmt_len 229 | 230 | if stmt_end >= current_pos: 231 | text_before_cursor = full_text[stmt_start:current_pos] 232 | full_text = full_text[stmt_start:] 233 | break 234 | 235 | elif parsed: 236 | # A single statement 237 | statement = parsed[0] 238 | else: 239 | # The empty string 240 | return full_text, text_before_cursor, None 241 | 242 | token2 = None 243 | if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'): 244 | token1 = statement.token_first() 245 | if token1: 246 | token1_idx = statement.token_index(token1) 247 | token2 = statement.token_next(token1_idx)[1] 248 | if token2 and token2.value.upper() == 'FUNCTION': 249 | full_text, text_before_cursor, statement = _statement_from_function( 250 | full_text, text_before_cursor, statement 251 | ) 252 | return full_text, text_before_cursor, statement 253 | 254 | 255 | SPECIALS_SUGGESTION = { 256 | 'lf': Function, 257 | 'lt': Table, 258 | 'lv': View, 259 | 'sf': Function, 260 | } 261 | 262 | 263 | #def suggest_special(text): 264 | # text = text.lstrip() 265 | # cmd, _, arg = parse_special_command(text) 266 | # 267 | # if cmd == text: 268 | # # Trying to complete the special command itself 269 | # return (Special(),) 270 | # 271 | # if cmd in ('\\c', '\\connect'): 272 | # return (Database(),) 273 | # 274 | # if cmd == '\\ls': 275 | # return (Schema(),) 276 | # 277 | # if arg: 278 | # # Try to 
distinguish "\d name" from "\d schema.name" 279 | # # Note that this will fail to obtain a schema name if wildcards are 280 | # # used, e.g. "\d schema???.name" 281 | # parsed = sqlparse.parse(arg)[0].tokens[0] 282 | # try: 283 | # schema = parsed.get_parent_name() 284 | # except AttributeError: 285 | # schema = None 286 | # else: 287 | # schema = None 288 | # 289 | # if cmd[1:] == 'd': 290 | # # \d can describe tables or views 291 | # if schema: 292 | # return (Table(schema=schema), 293 | # View(schema=schema),) 294 | # return (Schema(), 295 | # Table(schema=None), 296 | # View(schema=None),) 297 | # if cmd[1:] in SPECIALS_SUGGESTION: 298 | # rel_type = SPECIALS_SUGGESTION[cmd[1:]] 299 | # if schema: 300 | # if rel_type == Function: 301 | # return (Function(schema=schema, usage='special'),) 302 | # return (rel_type(schema=schema),) 303 | # if rel_type == Function: 304 | # return (Schema(), Function(schema=None, usage='special'),) 305 | # return (Schema(), rel_type(schema=None)) 306 | # 307 | # if cmd in ['\\n', '\\sn', '\\dn']: 308 | # return (NamedQuery(),) 309 | # 310 | # return (Keyword(), Special()) 311 | 312 | 313 | def suggest_based_on_last_token(token, stmt): 314 | 315 | if isinstance(token, string_types): 316 | token_v = token.lower() 317 | elif isinstance(token, Comparison): 318 | # If 'token' is a Comparison type such as 319 | # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling 320 | # token.value on the comparison type will only return the lhs of the 321 | # comparison. In this case a.id. So we need to do token.tokens to get 322 | # both sides of the comparison and pick the last token out of that 323 | # list. 324 | token_v = token.tokens[-1].value.lower() 325 | elif isinstance(token, Where): 326 | # sqlparse groups all tokens from the where clause into a single token 327 | # list. This means that token.value may be something like 328 | # 'where foo > 5 and '. 
We need to look "inside" token.tokens to handle 329 | # suggestions in complicated where clauses correctly 330 | prev_keyword = stmt.reduce_to_prev_keyword() 331 | return suggest_based_on_last_token(prev_keyword, stmt) 332 | elif isinstance(token, Identifier): 333 | # If the previous token is an identifier, we can suggest datatypes if 334 | # we're in a parenthesized column/field list, e.g.: 335 | # CREATE TABLE foo (Identifier 336 | # CREATE FUNCTION foo (Identifier 337 | # If we're not in a parenthesized list, the most likely scenario is the 338 | # user is about to specify an alias, e.g.: 339 | # SELECT Identifier 340 | # SELECT foo FROM Identifier 341 | prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor) 342 | if prev_keyword and prev_keyword.value == '(': 343 | # Suggest datatypes 344 | return suggest_based_on_last_token('type', stmt) 345 | return (Keyword(),) 346 | else: 347 | token_v = token.value.lower() 348 | 349 | if not token: 350 | return (Keyword(), Special()) 351 | if token_v.endswith('('): 352 | p = sqlparse.parse(stmt.text_before_cursor)[0] 353 | 354 | if p.tokens and isinstance(p.tokens[-1], Where): 355 | # Four possibilities: 356 | # 1 - Parenthesized clause like "WHERE foo AND (" 357 | # Suggest columns/functions 358 | # 2 - Function call like "WHERE foo(" 359 | # Suggest columns/functions 360 | # 3 - Subquery expression like "WHERE EXISTS (" 361 | # Suggest keywords, in order to do a subquery 362 | # 4 - Subquery OR array comparison like "WHERE foo = ANY(" 363 | # Suggest columns/functions AND keywords. (If we wanted to be 364 | # really fancy, we could suggest only array-typed columns) 365 | 366 | column_suggestions = suggest_based_on_last_token('where', stmt) 367 | 368 | # Check for a subquery expression (cases 3 & 4) 369 | where = p.tokens[-1] 370 | prev_tok = where.token_prev(len(where.tokens) - 1)[1] 371 | 372 | if isinstance(prev_tok, Comparison): 373 | # e.g. 
"SELECT foo FROM bar WHERE foo = ANY(" 374 | prev_tok = prev_tok.tokens[-1] 375 | 376 | prev_tok = prev_tok.value.lower() 377 | if prev_tok == 'exists': 378 | return (Keyword(),) 379 | return column_suggestions 380 | 381 | # Get the token before the parens 382 | prev_tok = p.token_prev(len(p.tokens) - 1)[1] 383 | 384 | if (prev_tok and prev_tok.value and prev_tok.value.lower().split(' ')[-1] == 'using'): 385 | # tbl1 INNER JOIN tbl2 USING (col1, col2) 386 | tables = stmt.get_tables('before') 387 | 388 | # suggest columns that are present in more than one table 389 | return (Column(table_refs=tables, 390 | require_last_table=True, 391 | local_tables=stmt.local_tables),) 392 | 393 | if p.token_first().value.lower() == 'select': 394 | # If the lparen is preceeded by a space chances are we're about to 395 | # do a sub-select. 396 | if last_word(stmt.text_before_cursor, 397 | 'all_punctuations').startswith('('): 398 | return (Keyword(),) 399 | prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1] 400 | if prev_prev_tok and prev_prev_tok.normalized == 'INTO': 401 | return ( 402 | Column(table_refs=stmt.get_tables('insert'), context='insert'), 403 | ) 404 | # We're probably in a function argument list 405 | return (Column(table_refs=extract_tables(stmt.full_text), 406 | local_tables=stmt.local_tables, qualifiable=True),) 407 | if token_v == 'set': 408 | return (Column(table_refs=stmt.get_tables(), 409 | local_tables=stmt.local_tables),) 410 | if token_v in ('select', 'where', 'having', 'order by', 'distinct'): 411 | # Check for a table alias or schema qualification 412 | parent = stmt.identifier.get_parent_name() \ 413 | if (stmt.identifier and stmt.identifier.get_parent_name()) else [] 414 | tables = stmt.get_tables() 415 | if parent: 416 | tables = tuple(t for t in tables if identifies(parent, t)) 417 | return (Column(table_refs=tables, local_tables=stmt.local_tables), 418 | Table(schema=parent), 419 | View(schema=parent), 420 | Function(schema=parent),) 
421 | return (Column(table_refs=tables, local_tables=stmt.local_tables, 422 | qualifiable=True), 423 | Function(schema=None), 424 | Keyword(token_v.upper()),) 425 | if token_v == 'as': 426 | # Don't suggest anything for aliases 427 | return () 428 | if (token_v.endswith('join') and token.is_keyword) or \ 429 | token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate'): 430 | 431 | catalog, schema = stmt.get_identifier_parents() 432 | tables = extract_tables(stmt.text_before_cursor) 433 | is_join = token_v.endswith('join') and token.is_keyword 434 | 435 | # Suggest tables from either the currently-selected schema or the 436 | # public schema if no schema has been specified 437 | suggest = [] 438 | 439 | if catalog is None and schema is None: 440 | suggest.insert(0, Database()) 441 | suggest.append(Schema()) 442 | elif not catalog: 443 | suggest.insert(0, Schema(parent = schema)) 444 | if token_v == 'from' or is_join: 445 | suggest.append(FromClauseItem(grandparent = catalog, 446 | parent = schema, 447 | table_refs=tables, 448 | local_tables=stmt.local_tables)) 449 | elif token_v == 'truncate': 450 | suggest.append(Table(catalog, schema)) 451 | else: 452 | suggest.extend((Table(catalog, schema), View(catalog, schema))) 453 | 454 | # TODO: Join(catalog = catalog, ...) 455 | if is_join and _allow_join(stmt.parsed): 456 | tables = stmt.get_tables('before') 457 | suggest.append(Join( 458 | table_refs = tables, 459 | schema=schema, 460 | catalog = catalog)) 461 | 462 | return tuple(suggest) 463 | 464 | # TODO: Use get_identifier_parents 465 | if token_v == 'function': 466 | schema = stmt.get_identifier_schema() 467 | # stmt.get_previous_token will fail for e.g. 
`SELECT 1 FROM functions 468 | # WHERE function:` 469 | try: 470 | prev = stmt.get_previous_token(token).value.lower() 471 | if prev in('drop', 'alter', 'create', 'create or replace'): 472 | return (Function(schema=schema, usage='signature'),) 473 | except ValueError: 474 | pass 475 | return tuple() 476 | 477 | if token_v in ('table', 'view'): 478 | # E.g. 'ALTER TABLE ' 479 | rel_type = { 480 | 'table': Table, 481 | 'view': View, 482 | 'function': Function}[token_v] 483 | catalog, schema = stmt.get_identifier_parents() 484 | suggest = [] 485 | if catalog is None and schema is None: 486 | suggest.insert(0, Database()) 487 | suggest.append(Schema()) 488 | elif not catalog: 489 | suggest.insert(0, Schema(parent = schema)) 490 | suggest.append(rel_type(catalog = catalog, schema = schema)) 491 | 492 | return suggest 493 | 494 | if token_v == 'column': 495 | # E.g. 'ALTER TABLE foo ALTER COLUMN bar 496 | return (Column(table_refs=stmt.get_tables()),) 497 | 498 | if token_v == 'on': 499 | tables = stmt.get_tables('before') 500 | parent = stmt.identifier.get_parent_name() \ 501 | if (stmt.identifier and stmt.identifier.get_parent_name()) else None 502 | if parent: 503 | # "ON parent." 
504 | # parent can be either a schema name or table alias 505 | filteredtables = tuple(t for t in tables if identifies(parent, t)) 506 | sugs = [Column(table_refs=filteredtables, 507 | local_tables=stmt.local_tables), 508 | Table(schema=parent), 509 | View(schema=parent), 510 | Function(schema=parent)] 511 | if filteredtables and _allow_join_condition(stmt.parsed): 512 | sugs.append(JoinCondition(table_refs=tables, 513 | parent=filteredtables[-1])) 514 | return tuple(sugs) 515 | # ON 516 | # Use table alias if there is one, otherwise the table name 517 | aliases = tuple(t.ref for t in tables) 518 | if _allow_join_condition(stmt.parsed): 519 | return (Alias(aliases=aliases), JoinCondition( 520 | table_refs=tables, parent=None)) 521 | return (Alias(aliases=aliases),) 522 | 523 | if token_v in ('c', 'use', 'database', 'template'): 524 | # "\c ", "DROP DATABASE ", 525 | # "CREATE DATABASE WITH TEMPLATE " 526 | return (Database(),) 527 | if token_v == 'schema': 528 | # DROP SCHEMA schema_name, SET SCHEMA schema name 529 | prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2) 530 | quoted = prev_keyword and prev_keyword.value.lower() == 'set' 531 | return (Schema(quoted),) 532 | if token_v.endswith(',') or token_v in ('=', 'and', 'or'): 533 | prev_keyword = stmt.reduce_to_prev_keyword() 534 | if prev_keyword: 535 | return suggest_based_on_last_token(prev_keyword, stmt) 536 | return () 537 | if token_v in ('type', '::'): 538 | # ALTER TABLE foo SET DATA TYPE bar 539 | # SELECT foo::bar 540 | # Note that tables are a form of composite type in postgresql, so 541 | # they're suggested here as well 542 | # TODO: Use get_identifier_parents 543 | schema = stmt.get_identifier_schema() 544 | suggestions = [Datatype(schema=schema), 545 | Table(schema=schema)] 546 | if not schema: 547 | suggestions.append(Schema()) 548 | return tuple(suggestions) 549 | if token_v in {'alter', 'create', 'drop'}: 550 | return (Keyword(token_v.upper()),) 551 | if token_v in {'limit'}: 552 | return 
(Blank(),) 553 | if token.is_keyword: 554 | # token is a keyword we haven't implemented any special handling for 555 | # go backwards in the query until we find one we do recognize 556 | prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1) 557 | if prev_keyword: 558 | return suggest_based_on_last_token(prev_keyword, stmt) 559 | return (Keyword(token_v.upper()),) 560 | return (Keyword(),) 561 | 562 | 563 | def identifies(string_id, ref): 564 | """Returns true if string `string_id` matches TableReference `ref`""" 565 | 566 | return string_id == ref.alias or string_id == ref.name or ( 567 | ref.schema and (string_id == ref.schema + '.' + ref.name)) 568 | 569 | 570 | def _allow_join_condition(statement): 571 | """ 572 | Tests if a join condition should be suggested 573 | 574 | We need this to avoid bad suggestions when entering e.g. 575 | select * from tbl1 a join tbl2 b on a.id = 576 | So check that the preceding token is a ON, AND, or OR keyword, instead of 577 | e.g. an equals sign. 578 | 579 | :param statement: an sqlparse.sql.Statement 580 | :return: boolean 581 | """ 582 | 583 | if not statement or not statement.tokens: 584 | return False 585 | 586 | last_tok = statement.token_prev(len(statement.tokens))[1] 587 | return last_tok.value.lower() in ('on', 'and', 'or') 588 | 589 | 590 | def _allow_join(statement): 591 | """ 592 | Tests if a join should be suggested 593 | 594 | We need this to avoid bad suggestions when entering e.g. 
595 | select * from tbl1 a join tbl2 b 596 | So check that the preceding token is a JOIN keyword 597 | 598 | :param statement: an sqlparse.sql.Statement 599 | :return: boolean 600 | """ 601 | 602 | if not statement or not statement.tokens: 603 | return False 604 | 605 | last_tok = statement.token_prev(len(statement.tokens))[1] 606 | return (last_tok.value.lower().endswith('join') and 607 | last_tok.value.lower() not in('cross join', 'natural join')) 608 | 609 | -------------------------------------------------------------------------------- /odbcli/config.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import shutil 3 | import os 4 | import platform 5 | from os.path import expanduser, exists, dirname 6 | from configobj import ConfigObj 7 | 8 | 9 | def config_location(): 10 | if "XDG_CONFIG_HOME" in os.environ: 11 | return "%s/odbcli/" % expanduser(os.environ["XDG_CONFIG_HOME"]) 12 | elif platform.system() == "Windows": 13 | return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\odbcli\\" 14 | else: 15 | return expanduser("~/.config/odbcli/") 16 | 17 | 18 | def load_config(usr_cfg, def_cfg=None): 19 | cfg = ConfigObj() 20 | cfg.merge(ConfigObj(def_cfg, interpolation=False)) 21 | cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) 22 | cfg.filename = expanduser(usr_cfg) 23 | 24 | return cfg 25 | 26 | 27 | def ensure_dir_exists(path): 28 | parent_dir = expanduser(dirname(path)) 29 | os.makedirs(parent_dir, exist_ok=True) 30 | 31 | 32 | def write_default_config(source, destination, overwrite=False): 33 | destination = expanduser(destination) 34 | if not overwrite and exists(destination): 35 | return 36 | 37 | ensure_dir_exists(destination) 38 | 39 | shutil.copyfile(source, destination) 40 | 41 | 42 | def upgrade_config(config, def_config): 43 | cfg = load_config(config, def_config) 44 | cfg.write() 45 | 46 | 47 | def get_config(odbclirc_file=None): 48 | from odbcli import 
__file__ as package_root 49 | 50 | package_root = os.path.dirname(package_root) 51 | 52 | odbclirc_file = odbclirc_file or "%sconfig" % config_location() 53 | 54 | default_config = os.path.join(package_root, "odbclirc") 55 | write_default_config(default_config, odbclirc_file) 56 | 57 | return load_config(odbclirc_file, default_config) 58 | -------------------------------------------------------------------------------- /odbcli/conn.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from cyanodbc import connect, Connection, SQLGetInfo, Cursor, DatabaseError, ConnectError 3 | from typing import Optional 4 | from cli_helpers.tabular_output import TabularOutputFormatter 5 | from logging import getLogger 6 | from re import sub 7 | from threading import Lock, Event, Thread, Condition 8 | from asyncio import get_event_loop 9 | from enum import IntEnum 10 | from .dbmetadata import DbMetadata 11 | 12 | formatter = TabularOutputFormatter() 13 | 14 | class connStatus(Enum): 15 | DISCONNECTED = 0 16 | IDLE = 1 17 | EXECUTING = 2 18 | FETCHING = 3 19 | ERROR = 4 20 | 21 | class executionStatus(IntEnum): 22 | OK = 0 23 | FAIL = 1 24 | OKWRESULTS = 2 25 | 26 | class sqlConnection: 27 | def __init__( 28 | self, 29 | dsn: str, 30 | conn: Optional[Connection] = Connection(), 31 | username: Optional[str] = "", 32 | password: Optional[str] = "" 33 | ) -> None: 34 | self.dsn = dsn 35 | self.conn = conn 36 | self.cursor: Cursor = None 37 | self.query: str = None 38 | self.username = username 39 | self.password = password 40 | self.logger = getLogger(__name__) 41 | self.dbmetadata = DbMetadata() 42 | self._quotechar = None 43 | self._search_escapechar = None 44 | self._search_escapepattern = None 45 | # Lock to be held by database interaction that happens 46 | # in the main process. 
Recall, main-buffer as well as preview 47 | # buffer queries get executed in a separate process, however 48 | # auto-completion, as well as object browser expansion happen 49 | # in the main process possibly multi-threaded. Multi threaded is fine 50 | # we don't want the main process to lock-up while writing a query, 51 | # however, we don't want to potentially hammer the connection with 52 | # multiple auto-completion result queries before each has had a chance 53 | # to return. 54 | self._lock = Lock() 55 | 56 | # Lock to be held when updating self.status, which can happen from 57 | # a thread 58 | self._status_lock = Lock() 59 | # Lock to be held when updating self._execution_status_lock which 60 | # can happen from a thread 61 | self._execution_status_lock = Lock() 62 | 63 | # Lock that protects interaction with _fetch_res 64 | self._fetch_cv = Condition() 65 | # This is the list that carries the cache of retrieved rows via the 66 | # asynchronous fetch operation 67 | self._fetch_res: list = [] 68 | self._fetch_thread = Thread() 69 | self._execution_thread = Thread() 70 | self._cancel_async_event = Event() 71 | 72 | self._status = connStatus.DISCONNECTED 73 | self._execution_status: executionStatus = executionStatus.OK 74 | self._execution_err: str = None 75 | 76 | @property 77 | def execution_status(self) -> executionStatus: 78 | """ Hold the lock here since it gets assigned in execute 79 | which can be called in a different thread """ 80 | with self._execution_status_lock: 81 | res = self._execution_status 82 | return res 83 | 84 | @property 85 | def status(self) -> connStatus: 86 | """ Hold the lock here since it can be assigned in more than one 87 | thread """ 88 | with self._status_lock: 89 | res = self._status 90 | return res 91 | 92 | @property 93 | def execution_err(self) -> str: 94 | """ Last execution error: Cleared prior to every execution. 
95 | Hold the lock here since it gets assigned in execute 96 | which can be called in a different thread """ 97 | with self._lock: 98 | res = self._execution_err 99 | return res 100 | 101 | @property 102 | def quotechar(self) -> str: 103 | if self._quotechar is None: 104 | self._quotechar = self.conn.get_info( 105 | SQLGetInfo.SQL_IDENTIFIER_QUOTE_CHAR) 106 | # pyodbc note 107 | # self._quotechar = self.conn.getinfo( 108 | return self._quotechar 109 | 110 | @property 111 | def search_escapechar(self) -> str: 112 | if self._search_escapechar is None: 113 | self._search_escapechar = self.conn.get_info( 114 | SQLGetInfo.SQL_SEARCH_PATTERN_ESCAPE) 115 | return self._search_escapechar 116 | 117 | @property 118 | def search_escapepattern(self) -> str: 119 | if self._search_escapepattern is None: 120 | # https://stackoverflow.com/questions/2428117/casting-raw-strings-python 121 | self._search_escapepattern = \ 122 | (self.search_escapechar).encode("unicode-escape").decode() 123 | 124 | return self._search_escapepattern 125 | 126 | def update_status(self, status: connStatus = connStatus.IDLE) -> None: 127 | """ Thread safe way of updating the connection status 128 | """ 129 | with self._status_lock: 130 | self._status = status 131 | 132 | def update_execution_status(self, status: executionStatus = executionStatus.OK) -> None: 133 | """ Thread safe way of updating the execution status 134 | """ 135 | with self._execution_status_lock: 136 | self._execution_status = status 137 | 138 | def sanitize_search_string(self, term) -> str: 139 | if term is not None and len(term): 140 | res = sub("(_|%)", self.search_escapepattern + "\\1", term) 141 | else: 142 | res = term 143 | return res 144 | 145 | def unsanitize_search_string(self, term) -> str: 146 | if term is not None and len(term): 147 | res = sub(self.search_escapepattern, "", term) 148 | else: 149 | res = term 150 | return res 151 | 152 | def escape_name(self, name): 153 | if name: 154 | qtchar = self.quotechar 155 | name = 
(qtchar + "%s" + qtchar) % name 156 | return name 157 | 158 | def escape_names(self, names): 159 | return [self.escape_name(name) for name in names] 160 | 161 | def unescape_name(self, name): 162 | """ Unquote a string.""" 163 | if name: 164 | qtchar = self.quotechar 165 | if name and name[0] == qtchar and name[-1] == qtchar: 166 | name = name[1:-1] 167 | 168 | return name 169 | 170 | def connect( 171 | self, 172 | username: str = "", 173 | password: str = "", 174 | force: bool = False) -> None: 175 | uid = username or self.username 176 | pwd = password or self.password 177 | conn_str = "DSN=" + self.dsn + ";" 178 | if len(uid): 179 | self.username = uid 180 | conn_str = conn_str + "UID=" + uid + ";" 181 | if len(pwd): 182 | self.password = pwd 183 | conn_str = conn_str + "PWD=" + pwd + ";" 184 | if force or not self.conn.connected(): 185 | try: 186 | self.conn = connect(dsn = conn_str, timeout = 5) 187 | self.update_status(connStatus.IDLE) 188 | except ConnectError as e: 189 | self.logger.error("Error while connecting: %s", str(e)) 190 | raise ConnectError(e) 191 | 192 | def fetchmany(self, size) -> list: 193 | with self._lock: 194 | if self.cursor: 195 | # This gets called in a thread / so exceptions can get lost 196 | # Make sure to recover after cyanodbc errors so that we can 197 | # complete the tail part of the worker (status/notify) 198 | try: 199 | res = self.cursor.fetchmany(size) 200 | except DatabaseError as e: 201 | self.logger.warning("Error while fetching: %s", str(e)) 202 | res = [] 203 | else: 204 | res = [] 205 | 206 | if len(res) < 1: 207 | self.update_status(connStatus.IDLE) 208 | with self._fetch_cv: 209 | self._fetch_res.extend(res) 210 | self._fetch_cv.notify() 211 | 212 | return res 213 | 214 | def async_fetchall(self, size, app) -> None: 215 | """ True asynchronous fetch. 
Will start a fetch, that will fetch 216 | *all* results (in chunks of size = size) in a background operation 217 | until the result set is depleted or we signal a stop via 218 | _cancel_async_event. After thread operation is completed, 219 | it asks the running event loop to redraw the app to pick up the 220 | new connection status (IDLE). 221 | """ 222 | self._cancel_async_event.clear() 223 | self.update_status(connStatus.FETCHING) 224 | loop = get_event_loop() 225 | def _run(): 226 | while True: 227 | res = self.fetchmany(size) 228 | if len(res) < 1 or self._cancel_async_event.is_set(): 229 | self.update_status(connStatus.IDLE) 230 | self._cancel_async_event.clear() 231 | # Should we try close cursor here? Problem is that 232 | # close_curor attempts to call cancel async which would 233 | # block until this thread is over 234 | break 235 | loop.call_soon_threadsafe(app.invalidate) 236 | return 237 | 238 | self._fetch_thread = Thread(target = _run, daemon = True) 239 | self._fetch_thread.start() 240 | return 241 | 242 | def fetch_from_cache(self, size, wait = False) -> list: 243 | """ Will grab the first size elements from self._fetch_res. Recall 244 | self._fetch_res is the result cache that is built up via an async 245 | fetch that grabs rows in chunks. Here, in a threadsafe manner we 246 | wait for a the asynchronous method to grab enough elements or 247 | finish the fetch operation altogether, then we 'pop' from the 248 | fetch result cache. 249 | """ 250 | with self._fetch_cv: 251 | if wait: 252 | self._fetch_cv.wait_for( 253 | lambda: len(self._fetch_res) > size or self.status == connStatus.IDLE) 254 | res = self._fetch_res[:size] 255 | del self._fetch_res[:size] 256 | return res 257 | 258 | 259 | def cancel_async_fetchall(self) -> None: 260 | """ Signal fetching thread to terminate operation, then wait / block 261 | until thread terminates. Also clear the fetch result cache. 
262 | """ 263 | self.logger.debug("cancel_async_fetchall ...") 264 | self._cancel_async_event.set() 265 | if self._fetch_thread.is_alive(): 266 | self._fetch_thread.join() 267 | with self._fetch_cv: 268 | self._fetch_res = [] 269 | 270 | def execute(self, query, parameters = None, event: Event = None) -> Cursor: 271 | self.logger.debug("Execute: %s", query) 272 | with self._lock: 273 | self.cursor = self.conn.cursor() 274 | try: 275 | self._execution_err = None 276 | self.update_status(connStatus.EXECUTING) 277 | self.cursor.execute(query, parameters) 278 | self.update_status(connStatus.IDLE) 279 | self.update_execution_status(executionStatus.OK) 280 | self.query = query 281 | except DatabaseError as e: 282 | self.update_status(connStatus.IDLE) 283 | self.update_execution_status(executionStatus.FAIL) 284 | self._execution_err = str(e) 285 | self.logger.warning("Execution error: %s", str(e)) 286 | if event is not None: 287 | event.set() 288 | return self.cursor 289 | 290 | def async_execute(self, query) -> Cursor: 291 | """ async_ is a misnomer here. It does execute in a new thread 292 | however it will also wait for execution to complete. At this time 293 | this helps us with registering KeyboardInterrupt during cyanodbc. 294 | execute only; it may evolve to have more true async-like behavior. 
295 | """ 296 | self.close_cursor() 297 | exec_event = Event() 298 | self._execution_thread = Thread( 299 | target = self.execute, 300 | kwargs = {"query": query, "parameters": None, "event": exec_event}, 301 | daemon = True) 302 | self._execution_thread.start() 303 | # Will block but can be interrupted 304 | exec_event.wait() 305 | return self.cursor 306 | 307 | def list_catalogs(self) -> list: 308 | # pyodbc note 309 | # return conn.cursor().tables(catalog = "%").fetchall() 310 | res = [] 311 | if self.status != connStatus.IDLE: 312 | return res 313 | try: 314 | if self.conn.connected(): 315 | self.logger.debug("Calling list_catalogs...") 316 | with self._lock: 317 | res = self.conn.list_catalogs() 318 | self.logger.debug("list_catalogs: done") 319 | except DatabaseError as e: 320 | self.update_status(connStatus.ERROR) 321 | self.logger.warning("list_catalogs: %s", str(e)) 322 | 323 | return res 324 | 325 | def list_schemas(self, catalog = None) -> list: 326 | res = [] 327 | 328 | # We only trust this generic implementation if attempting to list 329 | # schemata in curent catalog (or catalog argument is None) 330 | if catalog is not None and not catalog == self.current_catalog(): 331 | return res 332 | 333 | if self.status != connStatus.IDLE: 334 | return res 335 | try: 336 | if self.conn.connected(): 337 | self.logger.debug("Calling list_schemas...") 338 | with self._lock: 339 | res = self.conn.list_schemas() 340 | self.logger.debug("list_schemas: done") 341 | except DatabaseError as e: 342 | self.update_status(connStatus.ERROR) 343 | self.logger.warning("list_schemas: %s", str(e)) 344 | 345 | return res 346 | 347 | def find_tables( 348 | self, 349 | catalog = "", 350 | schema = "", 351 | table = "", 352 | type = "") -> list: 353 | res = [] 354 | 355 | if self.status != connStatus.IDLE: 356 | return res 357 | try: 358 | if self.conn.connected(): 359 | self.logger.debug("Calling find_tables: %s, %s, %s, %s", 360 | catalog, schema, table, type) 361 | with 
self._lock: 362 | res = self.conn.find_tables( 363 | catalog = catalog, 364 | schema = schema, 365 | table = table, 366 | type = type) 367 | self.logger.debug("find_tables: done") 368 | except DatabaseError as e: 369 | self.logger.warning("find_tables: %s.%s.%s, type %s: %s", catalog, schema, table, type, str(e)) 370 | 371 | return res 372 | 373 | def find_columns( 374 | self, 375 | catalog = "", 376 | schema = "", 377 | table = "", 378 | column = "") -> list: 379 | res = [] 380 | if self.status != connStatus.IDLE: 381 | return res 382 | 383 | try: 384 | if self.conn.connected(): 385 | self.logger.debug("Calling find_columns: %s, %s, %s, %s", 386 | catalog, schema, table, column) 387 | with self._lock: 388 | res = self.conn.find_columns( 389 | catalog = catalog, 390 | schema = schema, 391 | table = table, 392 | column = column) 393 | self.logger.debug("find_columns: done") 394 | except DatabaseError as e: 395 | self.logger.warning("find_columns: %s.%s.%s, column %s: %s", catalog, schema, table, column, str(e)) 396 | 397 | return res 398 | 399 | def find_procedures( 400 | self, 401 | catalog = "", 402 | schema = "", 403 | procedure = "") -> list: 404 | res = [] 405 | 406 | if self.status != connStatus.IDLE: 407 | return res 408 | 409 | try: 410 | if self.conn.connected(): 411 | self.logger.debug("Calling find_procedures: %s, %s, %s", 412 | catalog, schema, procedure) 413 | with self._lock: 414 | res = self.conn.find_procedures( 415 | catalog = catalog, 416 | schema = schema, 417 | procedure = procedure) 418 | self.logger.debug("find_procedures: done") 419 | except DatabaseError as e: 420 | self.logger.warning("find_procedures: %s.%s.%s: %s", catalog, schema, procedure, str(e)) 421 | 422 | return res 423 | 424 | def find_procedure_columns( 425 | self, 426 | catalog = "", 427 | schema = "", 428 | procedure = "", 429 | column = "") -> list: 430 | res = [] 431 | 432 | if self.status != connStatus.IDLE: 433 | return res 434 | 435 | try: 436 | if self.conn.connected(): 
437 | self.logger.debug("Calling find_procedure_columns: %s, %s, %s, %s", 438 | catalog, schema, procedure, column) 439 | with self._lock: 440 | res = self.conn.find_procedure_columns( 441 | catalog = catalog, 442 | schema = schema, 443 | procedure = procedure, 444 | column = column) 445 | self.logger.debug("find_procedure_columns: done") 446 | except DatabaseError as e: 447 | self.logger.warning("find_procedure_columns: %s.%s.%s, column %s: %s", catalog, schema, procedure, column, str(e)) 448 | 449 | return res 450 | 451 | def current_catalog(self) -> str: 452 | if self.conn.connected(): 453 | return self.conn.catalog_name 454 | return None 455 | 456 | def connected(self) -> bool: 457 | return self.conn.connected() 458 | 459 | def catalog_support(self) -> bool: 460 | res = self.conn.get_info(SQLGetInfo.SQL_CATALOG_NAME) 461 | return res == True or res == 'Y' 462 | # pyodbc note 463 | # return self.conn.getinfo(pyodbc.SQL_CATALOG_NAME) == True or self.conn.getinfo(pyodbc.SQL_CATALOG_NAME) == 'Y' 464 | 465 | def get_info(self, code: int) -> str: 466 | return self.conn.get_info(code) 467 | 468 | def close(self) -> None: 469 | self.logger.debug("close ...") 470 | # TODO: When disconnecting 471 | # We likely don't want to allow any exception to 472 | # propagate. Catch DatabaseError? 473 | if self.conn.connected(): 474 | self.conn.close() 475 | 476 | def close_cursor(self) -> None: 477 | self.logger.debug("Close cursor ...") 478 | self.cancel_async_fetchall() 479 | if self.cursor: 480 | with self._lock: 481 | self.cursor.close() 482 | self.cursor = None 483 | self.query = None 484 | self.update_status(connStatus.IDLE) 485 | 486 | def cancel(self) -> None: 487 | self.logger.debug("cancel ...") 488 | self.cancel_async_fetchall() 489 | if self.cursor: 490 | # Should not hold _lock here. 
Point here is to cancel execution 491 | # that might be taking place in a separate thread where the execution 492 | # lock is being held 493 | self.cursor.cancel() 494 | if self._execution_thread.is_alive(): 495 | self._execution_thread.join() 496 | self.query = None 497 | self.update_status(connStatus.IDLE) 498 | 499 | def preview_query( 500 | self, 501 | name, 502 | obj_type = "table", 503 | filter_query = "", 504 | limit = -1) -> str: 505 | """ Currently we only have a generic implementation for tables and 506 | views. Otherwise (functions) we return None 507 | """ 508 | if obj_type == "table" or obj_type == "view": 509 | qry = "SELECT * FROM " + name + " " + filter_query 510 | if limit > 0: 511 | qry = qry + " LIMIT " + str(limit) 512 | else: 513 | qry = None 514 | return qry 515 | 516 | def formatted_fetch(self, size, cols, format_name = "psql"): 517 | while True: 518 | res = self.fetch_from_cache(size, wait = True) 519 | if len(res) < 1 and self.status != connStatus.FETCHING: 520 | break 521 | if len(res) > 0: 522 | yield "\n".join( 523 | formatter.format_output( 524 | res, 525 | cols, 526 | format_name = format_name)) 527 | 528 | connWrappers = {} 529 | 530 | class MSSQL(sqlConnection): 531 | def find_tables( 532 | self, 533 | catalog = "", 534 | schema = "", 535 | table = "", 536 | type = "") -> list: 537 | """ FreeTDS does not allow us to query catalog == '', and 538 | schema = '' which, according to the ODBC spec for SQLTables should 539 | return tables outside of any catalog/schema. In the case of FreeTDS 540 | what gets passed to the sp_tables sproc is null, which in turn 541 | is interpreted as a wildcard. For the time being intercept 542 | these queries here (used in auto completion) and return empty 543 | set. 
""" 544 | 545 | if catalog == "\x00" and schema == "\x00": 546 | return [] 547 | 548 | return super().find_tables( 549 | catalog = catalog, 550 | schema = schema, 551 | table = table, 552 | type = type) 553 | 554 | def list_schemas(self, catalog = None) -> list: 555 | """ Optimization for listing out-of-database schemas by 556 | always querying catalog.sys.schemas. """ 557 | res = [] 558 | if self.status != connStatus.IDLE: 559 | return res 560 | 561 | qry = "SELECT name FROM {catalog}.sys.schemas " \ 562 | "WHERE name NOT IN ('db_owner', 'db_accessadmin', " \ 563 | "'db_securityadmin', 'db_ddladmin', 'db_backupoperator', " \ 564 | "'db_datareader', 'db_datawriter', 'db_denydatareader', " \ 565 | "'db_denydatawriter')" 566 | 567 | if catalog is None and self.current_catalog(): 568 | catalog_local = self.current_catalog() 569 | else: 570 | # We are going to be outright executing, versus 571 | # using the ODBC API. 572 | # let's make sure there is nothing escaped here 573 | catalog_local = self.unsanitize_search_string(catalog) 574 | 575 | if catalog_local: 576 | try: 577 | self.logger.debug("Calling list_schemas...") 578 | crsr = self.execute(qry.format(catalog = catalog_local)) 579 | res = crsr.fetchall() 580 | crsr.close() 581 | self.logger.debug("Calling list_schemas: done") 582 | schemas = [r[0] for r in res] 583 | if len(schemas): 584 | return schemas 585 | except DatabaseError as e: 586 | # execute has an exception handler, but the cursor calls may 587 | # throw 588 | self.close_cursor() 589 | self.logger.warning("MSSQL list_schemas: %s", str(e)) 590 | 591 | return super().list_schemas(catalog = catalog) 592 | 593 | def preview_query( 594 | self, 595 | name, 596 | obj_type = "table", 597 | filter_query = "", 598 | limit = -1) -> str: 599 | 600 | if obj_type == "table" or obj_type == "view": 601 | qry = " * FROM " + name + " " + filter_query 602 | if limit > 0: 603 | qry = "SELECT TOP " + str(limit) + qry 604 | else: 605 | qry = "SELECT" + qry 606 | elif 
obj_type == "function": 607 | # Sproc names in SQLServer come back with 608 | # catalog.schema.name;INT with the trailing suffix 609 | # not useful 610 | name_sanitized = sub("(;\\d{0,})(\")$", "\\2", name) 611 | catalog_local = sub("(.*)\\.(.*)\\.(.*)", "\\1", name) 612 | qry = "SELECT definition FROM {catalog}.sys.sql_modules " \ 613 | "WHERE object_id = (OBJECT_ID(N'{name}'))" 614 | qry = qry.format(catalog = catalog_local, name = name_sanitized) 615 | else: 616 | qry = None 617 | 618 | return qry 619 | 620 | class PSSQL(sqlConnection): 621 | def find_tables( 622 | self, 623 | catalog = "", 624 | schema = "", 625 | table = "", 626 | type = "") -> list: 627 | """ At least the psql odbc driver I am using has an annoying habbit 628 | of treating the catalog and schema fields interchangible, which 629 | in turn screws up with completion""" 630 | 631 | if not catalog in [self.current_catalog(), self.sanitize_search_string(self.current_catalog())]: 632 | return [] 633 | 634 | return super().find_tables( 635 | catalog = catalog, 636 | schema = schema, 637 | table = table, 638 | type = type) 639 | 640 | def find_procedures( 641 | self, 642 | catalog = "", 643 | schema = "", 644 | procedure = "") -> list: 645 | """ At least the psql odbc driver I am using has an annoying habbit 646 | of treating the catalog and schema fields interchangible, which 647 | in turn screws up with completion""" 648 | 649 | if not catalog in [self.current_catalog(), self.sanitize_search_string(self.current_catalog())]: 650 | return [] 651 | 652 | return super().find_procedures( 653 | catalog = catalog, 654 | schema = schema, 655 | procedure = procedure) 656 | 657 | def find_columns( 658 | self, 659 | catalog = "", 660 | schema = "", 661 | table = "", 662 | column = "") -> list: 663 | """ At least the psql odbc driver I am using has an annoying habbit 664 | of treating the catalog and schema fields interchangible, which 665 | in turn screws up with completion""" 666 | 667 | if not catalog in 
[self.current_catalog(), self.sanitize_search_string(self.current_catalog())]: 668 | return [] 669 | 670 | return super().find_columns( 671 | catalog = catalog, 672 | schema = schema, 673 | table = table, 674 | column = column) 675 | 676 | def find_procedure_columns( 677 | self, 678 | catalog = "", 679 | schema = "", 680 | procedure = "", 681 | column = "") -> list: 682 | """ At least the psql odbc driver I am using has an annoying habbit 683 | of treating the catalog and schema fields interchangible, which 684 | in turn screws up with completion. In addition wildcards in the column 685 | field, seem to not work - but an empty string does.""" 686 | 687 | if not catalog in [self.current_catalog(), self.sanitize_search_string(self.current_catalog())]: 688 | return [] 689 | 690 | if column == "%": 691 | column = "" 692 | 693 | return super().find_procedure_columns( 694 | catalog = catalog, 695 | schema = schema, 696 | procedure = procedure, 697 | column = column) 698 | 699 | class MySQL(sqlConnection): 700 | 701 | def list_schemas(self, catalog = None) -> list: 702 | """ Only catalogs for MySQL, it seems, 703 | however, list_schemas returns [""] which 704 | causes blank entries to show up in auto 705 | completion. Also confuses some of the checks we have 706 | that look for len(list_schemas) < 1 to decide whether 707 | to fall-back to find_tables. 
Make sure that for MySQL 708 | we do, in-fact fall-back to find_tables""" 709 | return [] 710 | 711 | def current_catalog(self) -> str: 712 | if self.conn.connected(): 713 | res = self.conn.catalog_name 714 | if res == "null": 715 | res = "" 716 | return res 717 | 718 | class Snowflake(sqlConnection): 719 | 720 | def find_tables( 721 | self, 722 | catalog = "", 723 | schema = "", 724 | table = "", 725 | type = "") -> list: 726 | 727 | type = type.upper() 728 | return super().find_tables( 729 | catalog = catalog, 730 | schema = schema, 731 | table = table, 732 | type = type) 733 | 734 | connWrappers["MySQL"] = MySQL 735 | connWrappers["Microsoft SQL Server"] = MSSQL 736 | connWrappers["PostgreSQL"] = PSSQL 737 | connWrappers["Snowflake"] = Snowflake 738 | -------------------------------------------------------------------------------- /odbcli/dbmetadata.py: -------------------------------------------------------------------------------- 1 | from threading import Lock, Event, Thread 2 | 3 | class DbMetadata(): 4 | def __init__(self) -> None: 5 | self._lock = Lock() 6 | self._dbmetadata = {'table': {}, 'view': {}, 'function': {}, 7 | 'datatype': {}} 8 | 9 | def extend_catalogs(self, names: list) -> None: 10 | with self._lock: 11 | for metadata in self._dbmetadata.values(): 12 | for catalog in names: 13 | metadata[catalog.lower()] = (catalog, {}) 14 | return 15 | 16 | def get_catalogs(self, obj_type: str = "table", cased: bool = True) -> list: 17 | """ Retrieve catalogs as the keys for _dbmetadata[obj_type] 18 | If no keys are found it returns None. 
19 | """ 20 | with self._lock: 21 | if cased: 22 | res = [casedkey for casedkey, mappedvalue in self._dbmetadata[obj_type].values()] 23 | else: 24 | res = list(self._dbmetadata[obj_type].keys()) 25 | 26 | 27 | if len(res) == 0: 28 | return None 29 | 30 | return res 31 | 32 | def extend_schemas(self, catalog, names: list) -> None: 33 | """ This method will force/create [catalog] dictionary 34 | in the event that len(names) > 0, and overwrite if 35 | anything was there to begin with. 36 | """ 37 | catlower = catalog.lower() 38 | cat_cased = catalog 39 | if len(names): 40 | with self._lock: 41 | for metadata in self._dbmetadata.values(): 42 | # Preserve casing if an entry already there 43 | if catlower in metadata.keys() and len(metadata[catlower]): 44 | cat_cased = metadata[catlower][0] 45 | metadata[catlower] = (cat_cased, {}) 46 | for schema in names: 47 | metadata[catlower][1][schema.lower()] = (schema, {}) 48 | return 49 | 50 | def get_schemas(self, catalog: str, obj_type: str = "table", cased: bool = True) -> list: 51 | """ Retrieve schemas as the keys for _dbmetadata[obj_type][catalog] 52 | If catalog is not part of the _dbmetadata[obj_type] keys will return 53 | None. 
54 | """ 55 | 56 | catlower = catalog.lower() 57 | cats = self.get_catalogs(obj_type = obj_type, cased = False) 58 | if cats is None or catlower not in cats: 59 | return None 60 | 61 | with self._lock: 62 | if cased: 63 | res = [casedkey for casedkey, mappedvalue in self._dbmetadata[obj_type][catlower][1].values()] 64 | else: 65 | res = list(self._dbmetadata[obj_type][catlower][1].keys()) 66 | 67 | 68 | return res 69 | 70 | def extend_objects(self, catalog, schema, names: list, obj_type: str) -> None: 71 | catlower = catalog.lower() 72 | schlower = schema.lower() 73 | if len(names): 74 | with self._lock: 75 | for otype in self._dbmetadata.keys(): 76 | # Loop over tables, views, functions 77 | if catlower not in self._dbmetadata[otype].keys(): 78 | self._dbmetadata[otype][catlower] = (catalog, {}) 79 | if schlower not in self._dbmetadata[otype][catlower][1].keys(): 80 | self._dbmetadata[otype][catlower][1][schlower] = (schema, {}) 81 | for obj in names: 82 | self._dbmetadata[obj_type][catlower][1][schlower][1][obj.lower()] = (obj, {}) 83 | # If we passed nothing then take out that element entirely out 84 | # of the dict 85 | else: 86 | with self._lock: 87 | del self._dbmetadata[obj_type][catlower][1][schlower] 88 | 89 | return 90 | 91 | def get_objects(self, catalog: str, schema: str, obj_type: str = "table") -> list: 92 | """ Retrieve objects as the keys for _dbmetadata[obj_type][catalog][schema] 93 | If catalog is not part of the _dbmetadata[obj_type] keys, or schema 94 | not one of the keys in _dbmetadata[obj_type][catalog] will return None 95 | """ 96 | 97 | catlower = catalog.lower() 98 | schlower = schema.lower() 99 | schemas = self.get_schemas(catalog = catalog, obj_type = obj_type, cased = False) 100 | if schemas is None or schlower not in schemas: 101 | return None 102 | 103 | with self._lock: 104 | res = [casedkey for casedkey, mappedvalue in self._dbmetadata[obj_type][catlower][1][schlower][1].values()] 105 | 106 | return list(res) 107 | 108 | def 
reset_metadata(self) -> None: 109 | with self._lock: 110 | self._dbmetadata = {'table': {}, 'view': {}, 'function': {}, 111 | 'datatype': {}} 112 | 113 | @property 114 | def data(self) -> dict: 115 | return self._dbmetadata 116 | -------------------------------------------------------------------------------- /odbcli/disconnect_dialog.py: -------------------------------------------------------------------------------- 1 | from prompt_toolkit.widgets import Button, Dialog, Label 2 | from prompt_toolkit.layout.containers import ConditionalContainer 3 | from prompt_toolkit.layout.dimension import Dimension as D 4 | from prompt_toolkit.filters import is_done 5 | from .conn import connWrappers 6 | from .filters import ShowDisconnectDialog 7 | 8 | def disconnect_dialog(my_app: "sqlApp"): 9 | def yes_handler() -> None: 10 | # This is not preferred since completer may currently have 11 | # a different connection attached. 12 | # my_app.completer.reset_completions() 13 | obj = my_app.selected_object 14 | obj.conn.dbmetadata.reset_metadata() 15 | obj.collapse() 16 | obj.conn.close() 17 | if my_app.active_conn is obj.conn: 18 | my_app.active_conn = None 19 | my_app.show_disconnect_dialog = False 20 | my_app.show_sidebar = True 21 | my_app.application.layout.focus("sidebarbuffer") 22 | 23 | def rc_handler() -> None: 24 | obj = my_app.selected_object 25 | obj.conn.dbmetadata.reset_metadata() 26 | my_app.show_disconnect_dialog = False 27 | my_app.show_sidebar = True 28 | my_app.application.layout.focus("sidebarbuffer") 29 | 30 | def no_handler() -> None: 31 | my_app.show_disconnect_dialog = False 32 | my_app.show_sidebar = True 33 | my_app.application.layout.focus("sidebarbuffer") 34 | 35 | dialog = Dialog( 36 | title = lambda: my_app.selected_object.name, 37 | body = Label(text = "Disconnect or Reset Completions?", 38 | dont_extend_height = True), 39 | buttons = [ 40 | Button(text = "Disconnect", handler = yes_handler), 41 | Button(text = "Reset Completions", handler = 
rc_handler, width = 20), 42 | Button(text = "Cancel", handler = no_handler), 43 | ], 44 | width = D(min = 10, preferred = 50), 45 | with_background = False, 46 | ) 47 | 48 | return ConditionalContainer( 49 | content = dialog, 50 | filter = ShowDisconnectDialog(my_app) & ~is_done 51 | ) 52 | -------------------------------------------------------------------------------- /odbcli/filters.py: -------------------------------------------------------------------------------- 1 | from prompt_toolkit.filters import Filter 2 | from prompt_toolkit.application import get_app 3 | from re import sub, findall 4 | 5 | class SqlAppFilter(Filter): 6 | def __init__(self, sql_app: "sqlApp") -> None: 7 | super().__init__() 8 | self.my_app = sql_app 9 | 10 | def __call__(self) -> bool: 11 | raise NotImplementedError 12 | 13 | class ShowSidebar(SqlAppFilter): 14 | def __call__(self) -> bool: 15 | return self.my_app.show_sidebar and not self.my_app.show_exit_confirmation 16 | 17 | class ShowLoginPrompt(SqlAppFilter): 18 | def __call__(self) -> bool: 19 | return self.my_app.show_login_prompt 20 | 21 | class ShowPreview(SqlAppFilter): 22 | def __call__(self) -> bool: 23 | return self.my_app.show_preview 24 | 25 | class ShowDisconnectDialog(SqlAppFilter): 26 | def __call__(self) -> bool: 27 | return self.my_app.show_disconnect_dialog 28 | 29 | class MultilineFilter(SqlAppFilter): 30 | def _is_open_quote(self, sql: str): 31 | """ To implement """ 32 | return False 33 | def _is_query_executable(self, sql: str): 34 | # A complete command is an sql statement that ends with a 'GO', unless 35 | # there's an open quote surrounding it, as is common when writing a 36 | # CREATE FUNCTION command 37 | if sql is not None and sql != "": 38 | # remove comments 39 | #esql = sqlparse.format(sql, strip_comments=True) 40 | # check for open comments 41 | # remove all closed quotes to isolate instances of open comments 42 | sql_no_quotes = sub(r'".*?"|\'.*?\'', '', sql) 43 | is_open_comment = 
len(findall(r'\/\*', sql_no_quotes)) > 0 44 | # check that 'go' is only token on newline 45 | lines = sql.split('\n') 46 | lastline = lines[len(lines) - 1].lower().strip() 47 | is_valid_go_on_lastline = lastline == 'go' 48 | # check that 'go' is on last line, not in open quotes, and there's no open 49 | # comment with closed comments and quotes removed. 50 | # NOTE: this method fails when GO follows a closing '*/' block comment on the same line, 51 | # we've taken a dependency with sqlparse 52 | # (https://github.com/andialbrecht/sqlparse/issues/484) 53 | return not self._is_open_quote(sql) and not is_open_comment and is_valid_go_on_lastline 54 | return False 55 | 56 | 57 | def _multiline_exception(self, text: str): 58 | text = text.strip() 59 | return ( 60 | text.startswith('\\') or # Special Command 61 | text.endswith(r'\e') or # Ended with \e which should launch the editor 62 | self._is_query_executable(text) or # A complete SQL command 63 | (text.endswith(';')) or # GO doesn't work everywhere 64 | (text == 'exit') or # Exit doesn't need semi-colon 65 | (text == 'quit') or # Quit doesn't need semi-colon 66 | (text == ':q') or # To all the vim fans out there 67 | (text == '') # Just a plain enter without any text 68 | ) 69 | 70 | def __call__(self) -> bool: 71 | doc = get_app().layout.get_buffer_by_name("defaultbuffer").document 72 | if not self.my_app.multiline: 73 | return False 74 | return not self._multiline_exception(doc.text) 75 | -------------------------------------------------------------------------------- /odbcli/layout.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from prompt_toolkit.layout.processors import AppendAutoSuggestion 3 | from prompt_toolkit.key_binding.vi_state import InputMode 4 | from prompt_toolkit.layout.containers import VSplit, HSplit, Window, ConditionalContainer, FloatContainer, Container, Float, ScrollOffsets 5 | from prompt_toolkit.buffer import Buffer 6 
| from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl 7 | from prompt_toolkit.document import Document 8 | from prompt_toolkit.filters import Condition, has_focus, is_done, renderer_height_is_known 9 | from prompt_toolkit.layout.layout import Layout 10 | from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter 11 | from prompt_toolkit.history import History, InMemoryHistory, FileHistory, ThreadedHistory 12 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest 13 | from prompt_toolkit.layout.dimension import Dimension 14 | from prompt_toolkit.widgets import SearchToolbar 15 | from prompt_toolkit.formatted_text import StyleAndTextTuples, to_formatted_text 16 | from prompt_toolkit.layout.menus import CompletionsMenu 17 | from prompt_toolkit.application import get_app 18 | from prompt_toolkit.mouse_events import MouseEvent 19 | from prompt_toolkit.lexers import PygmentsLexer 20 | from prompt_toolkit.selection import SelectionType 21 | from pygments.lexers.sql import SqlLexer 22 | from os.path import expanduser 23 | from .sidebar import sql_sidebar, sql_sidebar_help, show_sidebar_button_info, sql_sidebar_navigation 24 | from .loginprompt import login_prompt 25 | from .disconnect_dialog import disconnect_dialog 26 | from .preview import PreviewElement 27 | from .filters import ShowLoginPrompt, ShowSidebar, MultilineFilter 28 | from .utils import if_mousedown 29 | from .conn import connStatus 30 | from .config import config_location, ensure_dir_exists 31 | 32 | def get_inputmode_fragments(my_app: "sqlApp") -> StyleAndTextTuples: 33 | """ 34 | Return current input mode as a list of (token, text) tuples for use in a 35 | toolbar. 
36 | """ 37 | app = get_app() 38 | 39 | @if_mousedown 40 | def toggle_vi_mode(mouse_event: MouseEvent) -> None: 41 | my_app.vi_mode = not my_app.vi_mode 42 | 43 | token = "class:status-toolbar" 44 | input_mode_t = "class:status-toolbar.input-mode" 45 | 46 | mode = app.vi_state.input_mode 47 | result: StyleAndTextTuples = [] 48 | append = result.append 49 | 50 | # if my_app.title: 51 | if False: 52 | result.extend(to_formatted_text(my_app.title)) 53 | 54 | append((input_mode_t, "[F-4] ", toggle_vi_mode)) 55 | 56 | # InputMode 57 | if my_app.vi_mode: 58 | recording_register = app.vi_state.recording_register 59 | if recording_register: 60 | append((token, " ")) 61 | append((token + " class:record", "RECORD({})".format(recording_register))) 62 | append((token, " - ")) 63 | 64 | if app.current_buffer.selection_state is not None: 65 | if app.current_buffer.selection_state.type == SelectionType.LINES: 66 | append((input_mode_t, "Vi (VISUAL LINE)", toggle_vi_mode)) 67 | elif app.current_buffer.selection_state.type == SelectionType.CHARACTERS: 68 | append((input_mode_t, "Vi (VISUAL)", toggle_vi_mode)) 69 | append((token, " ")) 70 | elif app.current_buffer.selection_state.type == SelectionType.BLOCK: 71 | append((input_mode_t, "Vi (VISUAL BLOCK)", toggle_vi_mode)) 72 | append((token, " ")) 73 | elif mode in (InputMode.INSERT, "vi-insert-multiple"): 74 | append((input_mode_t, "Vi (INSERT)", toggle_vi_mode)) 75 | append((token, " ")) 76 | elif mode == InputMode.NAVIGATION: 77 | append((input_mode_t, "Vi (NAV)", toggle_vi_mode)) 78 | append((token, " ")) 79 | elif mode == InputMode.REPLACE: 80 | append((input_mode_t, "Vi (REPLACE)", toggle_vi_mode)) 81 | append((token, " ")) 82 | else: 83 | if app.emacs_state.is_recording: 84 | append((token, " ")) 85 | append((token + " class:record", "RECORD")) 86 | append((token, " - ")) 87 | 88 | append((input_mode_t, "Emacs", toggle_vi_mode)) 89 | append((token, " ")) 90 | 91 | append((input_mode_t, "[C-q] Exit Client", )) 92 | 93 | return 
result 94 | 95 | def get_connection_fragments(my_app: "sqlApp") -> StyleAndTextTuples: 96 | """ 97 | Return the active connection's status (or "Disconnected" when there is none) as a 98 | list of (token, text) tuples for use in the status toolbar. 99 | """ 100 | app = get_app() 101 | status = my_app.active_conn.status if my_app.active_conn else connStatus.DISCONNECTED 102 | if status == connStatus.FETCHING: 103 | token = "class:status-toolbar.conn-fetching" 104 | status_text = "Fetching" 105 | elif status == connStatus.EXECUTING: 106 | token = "class:status-toolbar.conn-executing" 107 | status_text = "Executing" 108 | elif status == connStatus.ERROR: 109 | token = "class:status-toolbar.conn-executing" 110 | status_text = "Unexpected Error" 111 | elif status == connStatus.DISCONNECTED: 112 | token = "class:status-toolbar.conn-fetching" 113 | status_text = "Disconnected" 114 | else: 115 | token = "class:status-toolbar.conn-idle" 116 | status_text = "Idle" 117 | 118 | result: StyleAndTextTuples = [] 119 | append = result.append 120 | 121 | append((token, " " + status_text)) 122 | return result 123 | 124 | def exit_confirmation( 125 | my_app: "sqlApp", style = "class:exit-confirmation" 126 | ) -> Container: 127 | """ 128 | Create `Layout` for the exit message. 129 | """ 130 | 131 | def get_text_fragments() -> StyleAndTextTuples: 132 | # Show "Do you really want to exit?" 133 | return [ 134 | (style, "\n %s ([y]/n)" % my_app.exit_message), 135 | ("[SetCursorPosition]", ""), 136 | (style, " \n"), 137 | ] 138 | 139 | visible = ~is_done & Condition(lambda: my_app.show_exit_confirmation) 140 | 141 | return ConditionalContainer( 142 | content=Window( 143 | FormattedTextControl(get_text_fragments), style=style 144 | ), 145 | filter=visible, 146 | ) 147 | 148 | 149 | 150 | def status_bar(my_app: "sqlApp") -> Container: 151 | """ 152 | Create the `Layout` for the status bar.
153 | """ 154 | TB = "class:status-toolbar" 155 | 156 | def get_text_fragments() -> StyleAndTextTuples: 157 | 158 | result: StyleAndTextTuples = [] 159 | append = result.append 160 | 161 | append((TB, " ")) 162 | result.extend(get_inputmode_fragments(my_app)) 163 | append((TB, " ")) 164 | result.extend(get_connection_fragments(my_app)) 165 | 166 | 167 | return result 168 | 169 | return ConditionalContainer( 170 | content=Window(content=FormattedTextControl(get_text_fragments), style=TB), 171 | filter=~is_done 172 | & renderer_height_is_known 173 | & Condition( 174 | lambda: not my_app.show_exit_confirmation 175 | ), 176 | ) 177 | 178 | def sql_line_prefix( 179 | line_number: int, 180 | wrap_count: int, 181 | my_app: "sqlApp" 182 | ) -> StyleAndTextTuples: 183 | if my_app.active_conn is not None: 184 | sqlConn = my_app.active_conn 185 | prompt = sqlConn.username + "@" + sqlConn.dsn + ":" + sqlConn.current_catalog() + " > " 186 | else: 187 | prompt = "> " 188 | if line_number == 0 and wrap_count == 0: 189 | return to_formatted_text([("class:prompt", prompt)]) 190 | prompt_width = len(prompt) 191 | return [("class:prompt.dots", "." 
* (prompt_width - 1) + " ")] 192 | 193 | class sqlAppLayout: 194 | def __init__( 195 | self, 196 | my_app: "sqlApp" 197 | ) -> None: 198 | 199 | self.my_app = my_app 200 | self.search_field = SearchToolbar() 201 | history_file = config_location() + 'history' 202 | ensure_dir_exists(history_file) 203 | hist = ThreadedHistory(FileHistory(expanduser(history_file))) 204 | self.input_buffer = Buffer( 205 | name = "defaultbuffer", 206 | tempfile_suffix = ".py", 207 | multiline = MultilineFilter(self.my_app), 208 | history = hist, 209 | completer = ThreadedCompleter(self.my_app.completer), 210 | auto_suggest = ThreadedAutoSuggest(AutoSuggestFromHistory()), 211 | complete_while_typing = Condition( 212 | lambda: self.my_app.active_conn is not None 213 | ) 214 | ) 215 | main_win_control = BufferControl( 216 | buffer = self.input_buffer, 217 | lexer = PygmentsLexer(SqlLexer), 218 | search_buffer_control = self.search_field.control, 219 | include_default_input_processors = False, 220 | input_processors = [AppendAutoSuggestion()], 221 | preview_search = True 222 | ) 223 | 224 | self.main_win = Window( 225 | main_win_control, 226 | height = ( 227 | lambda: ( 228 | None 229 | if get_app().is_done 230 | else (Dimension(min = self.my_app.min_num_menu_lines) if not self.my_app.show_preview else Dimension(min = self.my_app.min_num_menu_lines, preferred = 180)) 231 | ) 232 | ), 233 | get_line_prefix = partial(sql_line_prefix, my_app = self.my_app), 234 | scroll_offsets=ScrollOffsets(bottom = 1, left = 4, right = 4) 235 | ) 236 | 237 | preview_element = PreviewElement(self.my_app) 238 | self.lprompt = login_prompt(self.my_app) 239 | self.preview = preview_element.create_container() 240 | self.disconnect_dialog = disconnect_dialog(self.my_app) 241 | container = HSplit([ 242 | VSplit([ 243 | FloatContainer( 244 | content=HSplit( 245 | [ 246 | self.main_win, 247 | self.search_field, 248 | ] 249 | ), 250 | floats=[ 251 | Float( 252 | bottom = 1, 253 | left = 1, 254 | right = 0, 255 | 
content = sql_sidebar_help(self.my_app), 256 | ), 257 | Float( 258 | content = self.lprompt 259 | ), 260 | Float( 261 | content = self.preview, 262 | ), 263 | preview_element.create_completion_float(), 264 | Float( 265 | content = self.disconnect_dialog, 266 | ), 267 | Float( 268 | left = 2, 269 | bottom = 1, 270 | content = exit_confirmation(self.my_app) 271 | ), 272 | Float( 273 | xcursor = True, 274 | ycursor = True, 275 | transparent = True, 276 | content = CompletionsMenu( 277 | scroll_offset = 1, 278 | max_height = 16, 279 | extra_filter = has_focus(self.input_buffer) 280 | ) 281 | ) 282 | ] 283 | ), 284 | ConditionalContainer( 285 | content = sql_sidebar(self.my_app), 286 | filter=ShowSidebar(self.my_app) & ~is_done, 287 | ) 288 | ]), 289 | VSplit( 290 | [status_bar(self.my_app), show_sidebar_button_info(self.my_app)] 291 | ) 292 | ]) 293 | 294 | def accept(buff): 295 | app = get_app() 296 | app.exit(result = ["non-preview", buff.text]) 297 | app.pre_run_callables.append(buff.reset) 298 | return True 299 | 300 | self.input_buffer.accept_handler = accept 301 | self.layout = Layout(container, focused_element = self.main_win) 302 | -------------------------------------------------------------------------------- /odbcli/loginprompt.py: -------------------------------------------------------------------------------- 1 | from prompt_toolkit.buffer import Buffer 2 | from prompt_toolkit.layout.dimension import Dimension as D 3 | from prompt_toolkit.widgets import Button, Dialog, Label 4 | from prompt_toolkit.widgets import TextArea 5 | from prompt_toolkit.layout.containers import HSplit, ConditionalContainer, WindowAlign, Window 6 | from prompt_toolkit.filters import is_done 7 | from cyanodbc import ConnectError, DatabaseError, SQLGetInfo 8 | from .conn import connWrappers, sqlConnection 9 | from .filters import ShowLoginPrompt 10 | 11 | def login_prompt(my_app: "sqlApp"): 12 | 13 | def ok_handler() -> None: 14 | my_app.application.layout.focus(uidTextfield) 15 | 
obj = my_app.selected_object 16 | try: 17 | obj.conn.connect(username = uidTextfield.text, password = pwdTextfield.text) 18 | # Query the type of back-end and instantiate an appropriate class 19 | dbms = obj.conn.get_info(SQLGetInfo.SQL_DBMS_NAME) 20 | # Now clone object 21 | cls = connWrappers[dbms] if dbms in connWrappers.keys() else sqlConnection 22 | newConn = cls( 23 | dsn = obj.conn.dsn, 24 | conn = obj.conn.conn, 25 | username = obj.conn.username, 26 | password = obj.conn.password) 27 | obj.conn.close() 28 | newConn.connect() 29 | obj.conn = newConn 30 | my_app.active_conn = obj.conn 31 | # OG some thread locking may be needed here 32 | obj.expand() 33 | except ConnectError as e: 34 | msgLabel.text = "Connect failed" 35 | else: 36 | msgLabel.text = "" 37 | my_app.show_login_prompt = False 38 | my_app.show_sidebar = True 39 | my_app.application.layout.focus("sidebarbuffer") 40 | 41 | uidTextfield.text = "" 42 | pwdTextfield.text = "" 43 | 44 | def cancel_handler() -> None: 45 | msgLabel.text = "" 46 | my_app.application.layout.focus(uidTextfield) 47 | my_app.show_login_prompt = False 48 | my_app.show_sidebar = True 49 | my_app.application.layout.focus("sidebarbuffer") 50 | 51 | def accept(buf: Buffer) -> bool: 52 | my_app.application.layout.focus(ok_button) 53 | return True 54 | 55 | ok_button = Button(text="OK", handler=ok_handler) 56 | cancel_button = Button(text="Cancel", handler=cancel_handler) 57 | 58 | pwdTextfield = TextArea( 59 | multiline=False, password=True, accept_handler=accept 60 | ) 61 | uidTextfield = TextArea( 62 | multiline=False, password=False, accept_handler=accept 63 | ) 64 | msgLabel = Label(text = "", dont_extend_height = True, style = "class:frame.label") 65 | msgLabel.window.align = WindowAlign.CENTER 66 | dialog = Dialog( 67 | title="Server Credentials", 68 | body=HSplit( 69 | [ 70 | Label(text="Username ", dont_extend_height=True), 71 | uidTextfield, 72 | Label(text="Password", dont_extend_height=True), 73 | pwdTextfield, 74 | 
msgLabel 75 | ], 76 | padding=D(preferred=1, max=1), 77 | ), 78 | width = D(min = 10, preferred = 50), 79 | buttons=[ok_button, cancel_button], 80 | with_background = False 81 | ) 82 | 83 | return ConditionalContainer( 84 | content = dialog, 85 | filter = ShowLoginPrompt(my_app) & ~is_done 86 | ) 87 | -------------------------------------------------------------------------------- /odbcli/odbclirc: -------------------------------------------------------------------------------- 1 | # vi: ft=dosini 2 | [main] 3 | 4 | # Multi-line mode allows breaking up the sql statements into multiple lines. If 5 | # this is set to True, then the end of the statements must have a semi-colon. 6 | # If this is set to False then sql statements can't be split into multiple 7 | # lines. End of line (return) is considered as the end of the statement. 8 | multi_line = False 9 | 10 | # log_file location. 11 | # In Unix/Linux: ~/.config/odbcli/log 12 | # In Windows: %USERPROFILE%\AppData\Local\dbcli\odbcli\log 13 | # %USERPROFILE% is typically C:\Users\{username} 14 | log_file = default 15 | 16 | # history_file location. 17 | # In Unix/Linux: ~/.config/odbcli/history 18 | # In Windows: %USERPROFILE%\AppData\Local\dbcli\odbcli\history 19 | # %USERPROFILE% is typically C:\Users\{username} 20 | history_file = default 21 | 22 | # Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO" 23 | # and "DEBUG". "NONE" disables logging. 24 | log_level = INFO 25 | 26 | # Mouse support. Limited mouse support for positioning the cursor, for 27 | # scrolling, and for clicking in the autocompletions menu. Recommended 28 | # False, as it may cause surprising behavior when trying to use the mouse. 29 | mouse_support = False 30 | 31 | # Default pager. 32 | # By default 'PAGER' environment variable is used 33 | # pager = less -SRXF 34 | 35 | # Number of lines to reserve for pager non-content output 36 | # When table format is psql, for less this is 1, pypager 2.
Internally we use 37 | # this number to fetch just enough rows of data so that we can fit the table 38 | # columns at the top of each page - therefore this number also depends on the 39 | # format used 40 | pager_reserve_lines = 1 41 | 42 | # Timing of sql statement execution. 43 | timing = True 44 | 45 | # Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe, 46 | # ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs, 47 | # textile, moinmoin, jira, vertical, tsv, csv. 48 | # Recommended: psql, fancy_grid and grid. 49 | table_format = psql 50 | 51 | # Syntax Style. Possible values: manni, igor, xcode, vim, autumn, vs, rrt, 52 | # native, perldoc, borland, tango, emacs, friendly, monokai, paraiso-dark, 53 | # colorful, murphy, bw, pastie, paraiso-light, trac, default, fruity 54 | syntax_style = default 55 | 56 | # Keybindings: 57 | # When Vi mode is enabled you can use modal editing features offered by Vi in the REPL. 58 | # When Vi mode is disabled emacs keybindings such as Ctrl-A for home and Ctrl-E 59 | # for end are available in the REPL. 60 | vi = False 61 | 62 | # Number of lines to reserve for the suggestion menu 63 | min_num_menu_lines = 5 64 | 65 | 66 | # When viewing the results of a query in the main execution prompt of the client 67 | # this option determines how many (expressed as a multiplier of the number of 68 | # rows on the screen) results to pre-fetch before displaying the results screen. 69 | # The fetch operation is asynchronous / in a background process; this option 70 | # essentially determines the trade-off between smooth scrolling through the 71 | # results (higher multiplier) and how quickly the initial results are displayed 72 | # on the screen (lower multiplier). Generally a number >= 3 is recommended. 73 | fetch_chunk_multiplier = 5 74 | 75 | # When previewing a table we SELECT *. If preview_limit_rows is > 0 76 | # we attempt to limit the maximum number of rows fetched to this number. 
77 | preview_limit_rows = 500 78 | 79 | # When previewing a table, the asynchronous operation will fetch all the 80 | # results from the SELECT query in chunks of fixed sizes asynchronously. 81 | # This number should generally be larger than the number of rows on the screen 82 | # but comfortably less than the total number of rows limited by 83 | # preview_limit_rows. More likely than not, you should not need to change this 84 | # number. 85 | preview_fetch_chunk_size = 100 86 | 87 | # Custom colors for the completion menu, toolbar, etc. 88 | [colors] 89 | completion-menu.completion.current = 'bg:#ffffff #000000' 90 | completion-menu.completion = 'bg:#008888 #ffffff' 91 | completion-menu.meta.completion.current = 'bg:#44aaaa #000000' 92 | completion-menu.meta.completion = 'bg:#448888 #ffffff' 93 | completion-menu.multi-column-meta = 'bg:#aaffff #000000' 94 | scrollbar.arrow = 'bg:#003333' 95 | scrollbar = 'bg:#00aaaa' 96 | selected = '#ffffff bg:#6666aa' 97 | search = '#ffffff bg:#4444aa' 98 | search.current = '#ffffff bg:#44aa44' 99 | bottom-toolbar = 'bg:#222222 #aaaaaa' 100 | bottom-toolbar.off = 'bg:#222222 #888888' 101 | bottom-toolbar.on = 'bg:#222222 #ffffff' 102 | search-toolbar = 'noinherit bold' 103 | search-toolbar.text = 'nobold' 104 | system-toolbar = 'noinherit bold' 105 | arg-toolbar = 'noinherit bold' 106 | arg-toolbar.text = 'nobold' 107 | bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' 108 | bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' 109 | literal.string = '#ba2121' 110 | literal.number = '#666666' 111 | keyword = 'bold #008000' 112 | 113 | preview-output-field = 'bg:#000044 #ffffff' 114 | preview-input-field = 'bg:#000066 #ffffff' 115 | preview-divider-line = '#000000' 116 | status-toolbar = "bg:#222222 #aaaaaa" 117 | status-toolbar.title = "underline" 118 | status-toolbar.inputmode = "bg:#222222 #ffffaa" 119 | status-toolbar.key = "bg:#000000 #888888" 120 | status-toolbar.pastemodeon = "bg:#aa4444 #ffffff" 121 | 
status-toolbar.cli-version = "bg:#222222 #ffffff bold" 122 | status-toolbar paste-mode-on = "bg:#aa4444 #ffffff" 123 | record = "bg:#884444 white" 124 | status-toolbar.input-mode = "#ffff44" 125 | status-toolbar.conn-executing = "bg:red #ffff44" 126 | status-toolbar.conn-fetching = "bg:yellow black" 127 | status-toolbar.conn-idle = "bg:#668866 #ffffff" 128 | # The options sidebar. 129 | sidebar = "bg:#bbbbbb #000000" 130 | sidebar.title = "bg:#668866 fg:#ffffff" 131 | sidebar.label = "bg:#bbbbbb fg:#222222" 132 | sidebar.status = "bg:#dddddd #000011" 133 | sidebar.label selected = "bg:#222222 #eeeeee bold" 134 | sidebar.status selected = "bg:#444444 #ffffff bold" 135 | sidebar.label active = "bg:#668866 #ffffff" 136 | sidebar.status active = "bg:#88AA88 #ffffff" 137 | sidebar.separator = "underline" 138 | sidebar.navigation.key = "bg:#bbddbb #000000 bold" 139 | sidebar.navigation.description = "bg:#dddddd #000011" 140 | sidebar.navigation = "bg:#dddddd" 141 | sidebar.helptext = "bg:#fdf6e3 #000011" 142 | 143 | # Exit confirmation 144 | exit-confirmation = "bg:#884444 #ffffff" 145 | 146 | # style classes for colored table output 147 | output.header = "#00ff5f bold" 148 | output.odd-row = "" 149 | output.even-row = "" 150 | output.null = "#808080" 151 | -------------------------------------------------------------------------------- /odbcli/odbcstyle.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pygments.styles 4 | from pygments.token import string_to_tokentype, Token 5 | from pygments.style import Style as PygmentsStyle 6 | from pygments.util import ClassNotFound 7 | from prompt_toolkit.styles.pygments import style_from_pygments_cls 8 | from prompt_toolkit.styles import merge_styles, Style 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | # map Pygments tokens (ptk 1.0) to class names (ptk 2.0). 
13 | TOKEN_TO_PROMPT_STYLE = { 14 | Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", 15 | Token.Menu.Completions.Completion: "completion-menu.completion", 16 | Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", 17 | Token.Menu.Completions.Meta: "completion-menu.meta.completion", 18 | Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta", 19 | Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess 20 | Token.Menu.Completions.ProgressBar: "scrollbar", # best guess 21 | Token.SelectedText: "selected", 22 | Token.SearchMatch: "search", 23 | Token.SearchMatch.Current: "search.current", 24 | Token.Toolbar: "bottom-toolbar", 25 | Token.Toolbar.Off: "bottom-toolbar.off", 26 | Token.Toolbar.On: "bottom-toolbar.on", 27 | Token.Toolbar.Search: "search-toolbar", 28 | Token.Toolbar.Search.Text: "search-toolbar.text", 29 | Token.Toolbar.System: "system-toolbar", 30 | Token.Toolbar.Arg: "arg-toolbar", 31 | Token.Toolbar.Arg.Text: "arg-toolbar.text", 32 | Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid", 33 | Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed", 34 | Token.Output.Header: "output.header", 35 | Token.Output.OddRow: "output.odd-row", 36 | Token.Output.EvenRow: "output.even-row", 37 | Token.Output.Null: "output.null", 38 | Token.Literal.String: "literal.string", 39 | Token.Literal.Number: "literal.number", 40 | Token.Keyword: "keyword", 41 | } 42 | 43 | # reverse dict for cli_helpers, because they still expect Pygments tokens. 44 | PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} 45 | 46 | 47 | def parse_pygments_style(token_name, style_object, style_dict): 48 | """Parse token type and style string. 49 | 50 | :param token_name: str name of Pygments token. 
Example: "Token.String" 51 | :param style_object: pygments.style.Style instance to use as base 52 | :param style_dict: dict of token names and their styles, customized to this cli 53 | 54 | """ 55 | token_type = string_to_tokentype(token_name) 56 | try: 57 | other_token_type = string_to_tokentype(style_dict[token_name]) 58 | return token_type, style_object.styles[other_token_type] 59 | except AttributeError: 60 | return token_type, style_dict[token_name] 61 | 62 | 63 | def style_factory(name, cli_style): 64 | try: 65 | style = pygments.styles.get_style_by_name(name) 66 | except ClassNotFound: 67 | style = pygments.styles.get_style_by_name("native") 68 | 69 | prompt_styles = [] 70 | # prompt-toolkit used pygments tokens for styling before, switched to style 71 | # names in 2.0. Convert old token types to new style names, for backwards compatibility. 72 | for token in cli_style: 73 | if token.startswith("Token."): 74 | # treat as pygments token (1.0) 75 | token_type, style_value = parse_pygments_style(token, style, cli_style) 76 | if token_type in TOKEN_TO_PROMPT_STYLE: 77 | prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] 78 | prompt_styles.append((prompt_style, style_value)) 79 | else: 80 | # we don't want to support tokens anymore 81 | logger.error("Unhandled style / class name: %s", token) 82 | else: 83 | # treat as prompt style name (2.0). 
See default style names here: 84 | # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py 85 | prompt_styles.append((token, cli_style[token])) 86 | 87 | override_style = Style([("bottom-toolbar", "noreverse")]) 88 | return merge_styles( 89 | [style_from_pygments_cls(style), override_style, Style(prompt_styles)] 90 | ) 91 | 92 | 93 | def style_factory_output(name, cli_style): 94 | try: 95 | style = pygments.styles.get_style_by_name(name).styles 96 | except ClassNotFound: 97 | style = pygments.styles.get_style_by_name("native").styles 98 | 99 | for token in cli_style: 100 | if token.startswith("Token."): 101 | token_type, style_value = parse_pygments_style(token, style, cli_style) 102 | style.update({token_type: style_value}) 103 | elif token in PROMPT_STYLE_TO_TOKEN: 104 | token_type = PROMPT_STYLE_TO_TOKEN[token] 105 | style.update({token_type: cli_style[token]}) 106 | else: 107 | # TODO: cli helpers will have to switch to ptk.Style 108 | logger.error("Unhandled style / class name: %s", token) 109 | 110 | class OutputStyle(PygmentsStyle): 111 | default_style = "" 112 | styles = style 113 | 114 | return OutputStyle 115 | -------------------------------------------------------------------------------- /odbcli/preview.py: -------------------------------------------------------------------------------- 1 | from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightIncrementalSearchProcessor, HighlightSelectionProcessor, AppendAutoSuggestion 2 | from prompt_toolkit.buffer import Buffer 3 | from prompt_toolkit.layout.controls import BufferControl 4 | from prompt_toolkit.document import Document 5 | from prompt_toolkit.layout.menus import CompletionsMenu 6 | from prompt_toolkit.filters import has_focus, is_done 7 | from prompt_toolkit.layout.dimension import Dimension as D 8 | from prompt_toolkit.widgets import Button, TextArea, SearchToolbar, Box, Shadow, Frame 9 | from prompt_toolkit.layout.containers 
import Window, VSplit, HSplit, ConditionalContainer, FloatContainer, Float 10 | from prompt_toolkit.filters import Condition, is_done 11 | from prompt_toolkit.completion import Completer, Completion, ThreadedCompleter, CompleteEvent 12 | from prompt_toolkit.history import History, FileHistory, ThreadedHistory 13 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest, AutoSuggest, Suggestion 14 | from cyanodbc import ConnectError, DatabaseError 15 | from cli_helpers.tabular_output import TabularOutputFormatter 16 | from functools import partial 17 | from typing import Callable, Iterable, List, Optional 18 | from logging import getLogger 19 | from os.path import expanduser 20 | from .completion.mssqlcompleter import MssqlCompleter 21 | from .filters import ShowPreview 22 | from .conn import connWrappers, connStatus, executionStatus 23 | from .config import config_location, ensure_dir_exists 24 | 25 | def object_to_identifier(obj: "myDBObject") -> str: 26 | # TODO: Verify connected 27 | sql_conn = obj.conn 28 | catalog = None 29 | schema = None 30 | if obj.parent is not None: 31 | if type(obj.parent).__name__ == "myDBSchema": 32 | schema = obj.parent.name 33 | elif type(obj.parent).__name__ == "myDBCatalog": 34 | catalog = obj.parent.name 35 | if obj.parent.parent is not None: 36 | if type(obj.parent.parent).__name__ == "myDBCatalog": 37 | catalog = obj.parent.parent.name 38 | 39 | if catalog: 40 | catalog = (sql_conn.quotechar + "%s" + sql_conn.quotechar) % catalog 41 | if schema: 42 | schema = (sql_conn.quotechar + "%s" + sql_conn.quotechar) % schema 43 | name = (sql_conn.quotechar + "%s" + sql_conn.quotechar) % obj.name 44 | identifier = ".".join(list(filter(None, [catalog, schema, name]))) 45 | 46 | return identifier 47 | 48 | 49 | class PreviewCompleter(Completer): 50 | """ Wraps prompt_toolkit.Completer. The buffer that this completer is 51 | attached to only carries part of of the query: for example 'WHERE ...'. 
52 | To complete the query effectively we need the complete preview query 53 | and this completer constructs a document object that carries the full 54 | query and feeds it to the wrapped completer. 55 | Rather than wrapping, probably should extend the class - however 56 | at this time as the completer class is fairly hacked up and not 57 | in a steady state, let's stay with the wrap. 58 | """ 59 | def __init__(self, my_app: "sqlApp", completer: Completer) -> None: 60 | self.completer = completer 61 | self.my_app = my_app 62 | 63 | def get_completions( 64 | self, document: Document, complete_event: CompleteEvent 65 | ) -> Iterable[Completion]: 66 | obj = self.my_app.selected_object 67 | sql_conn = obj.conn 68 | identifier = object_to_identifier(obj) 69 | query = sql_conn.preview_query( 70 | name = identifier, 71 | obj_type = obj.otype, 72 | filter_query = document.text, 73 | limit = self.my_app.preview_limit_rows) 74 | if query is None: 75 | return [] 76 | 77 | new_document = Document(text = query, 78 | cursor_position = query.find(document.text) + document.cursor_position) 79 | return self.completer.get_completions(new_document, complete_event) 80 | 81 | class PreviewHistory(FileHistory): 82 | def __init__(self, filename: str, my_app: "sqlApp") -> None: 83 | self.my_app = my_app 84 | super().__init__(filename) 85 | 86 | def store_string(self, string: str) -> None: 87 | """ Store filtering query in history file by 88 | adding the "[identifier]: " prefix 89 | """ 90 | obj = self.my_app.selected_object 91 | identifier = object_to_identifier(obj) 92 | super().store_string(identifier + ": " + string) 93 | 94 | class PreviewSuggestFromHistory(AutoSuggest): 95 | """ 96 | Give suggestions based on the lines in the history. 
97 | """ 98 | def __init__(self, my_app: "sqlApp") -> None: 99 | self.my_app = my_app 100 | super().__init__ 101 | 102 | def get_suggestion( 103 | self, buffer: "Buffer", document: Document 104 | ) -> Optional[Suggestion]: 105 | """ 106 | When looking for most recent suggestion look for one 107 | starting with the "[identifier]: " prefix 108 | """ 109 | history = buffer.history 110 | 111 | # Consider only the last line for the suggestion. 112 | text = document.text.rsplit("\n", 1)[-1] 113 | # Only create a suggestion when this is not an empty line. 114 | if text.strip(): 115 | obj = self.my_app.selected_object 116 | prefix = object_to_identifier(obj) + ": " 117 | # Find first matching line in history. 118 | for string in reversed(list(history.get_strings())): 119 | for line in reversed(string.splitlines()): 120 | loc = line.find(prefix) 121 | # Add one character for a space after SELECT identifier 122 | if loc >= 0 and line[loc + len(prefix):].startswith(text): 123 | return Suggestion(line[loc + len(prefix) + len(text) :]) 124 | 125 | return None 126 | 127 | class PreviewBuffer(Buffer): 128 | def history_forward(self, count: int = 1) -> None: 129 | """ Disable searching through history on up/down arrow """ 130 | return None 131 | def history_backward(self, count: int = 1) -> None: 132 | """ Disable searching through history on up/down arrow """ 133 | return None 134 | 135 | 136 | class PreviewElement: 137 | """ Class to create the preview element. It contains two main methods: 138 | create_container: creates the main preview container. Intention is 139 | for this to land in a float. 140 | create_completion_float: creates the completion float in the preview 141 | container. Intention is for this to appear in the FloatContainer that 142 | hosts the main preview container float. 143 | """ 144 | def __init__(self, my_app: "sqlApp"): 145 | self.my_app = my_app 146 | help_text = """ 147 | Press Enter in the input box to page through the table. 
148 | Alternatively, enter a filtering SQL statement and then press Enter 149 | to page through the results. 150 | """ 151 | self.formatter = TabularOutputFormatter() 152 | self.completer = PreviewCompleter( 153 | my_app = self.my_app, 154 | completer = MssqlCompleter( 155 | smart_completion = True, 156 | get_conn = lambda: self.my_app.selected_object.conn)) 157 | 158 | history_file = config_location() + 'preview_history' 159 | ensure_dir_exists(history_file) 160 | hist = PreviewHistory( 161 | my_app = self.my_app, 162 | filename = expanduser(history_file)) 163 | 164 | self.input_buffer = PreviewBuffer( 165 | name = "previewbuffer", 166 | tempfile_suffix = ".sql", 167 | history = ThreadedHistory(hist), 168 | auto_suggest = 169 | ThreadedAutoSuggest(PreviewSuggestFromHistory(my_app)), 170 | completer = ThreadedCompleter(self.completer), 171 | # history = hist, 172 | # auto_suggest = PreviewSuggestFromHistory(my_app), 173 | # completer = self.completer, 174 | complete_while_typing = Condition( 175 | lambda: self.my_app.selected_object is not None and self.my_app.selected_object.conn.connected() 176 | ), 177 | multiline = False) 178 | 179 | input_control = BufferControl( 180 | buffer = self.input_buffer, 181 | include_default_input_processors = False, 182 | input_processors = [AppendAutoSuggestion()], 183 | preview_search = False) 184 | 185 | self.input_window = Window(input_control) 186 | 187 | search_buffer = Buffer(name = "previewsearchbuffer") 188 | self.search_field = SearchToolbar(search_buffer) 189 | self.output_field = TextArea(style = "class:preview-output-field", 190 | text = help_text, 191 | height = D(preferred = 50), 192 | search_field = self.search_field, 193 | wrap_lines = False, 194 | focusable = True, 195 | read_only = True, 196 | preview_search = True, 197 | input_processors = [ 198 | ConditionalProcessor( 199 | processor=HighlightIncrementalSearchProcessor(), 200 | filter=has_focus("previewsearchbuffer") 201 | | has_focus(self.search_field.control), 
202 | ), 203 | HighlightSelectionProcessor(), 204 | ]) 205 | 206 | def refresh_results(window_height) -> bool: 207 | """ This method gets called when the app restarts after 208 | exiting for execution of preview query. It populates 209 | the output buffer with results from the fetch/query. 210 | """ 211 | sql_conn = self.my_app.selected_object.conn 212 | if sql_conn.execution_status == executionStatus.FAIL: 213 | # Let's display the error message to the user 214 | output = sql_conn.execution_err 215 | else: 216 | crsr = sql_conn.cursor 217 | if crsr.description: 218 | cols = [col.name for col in crsr.description] 219 | else: 220 | cols = [] 221 | if len(cols): 222 | res = sql_conn.fetch_from_cache(size = window_height - 4, 223 | wait = True) 224 | output = self.formatter.format_output(res, cols, format_name = "psql") 225 | output = "\n".join(output) 226 | else: 227 | output = "No rows returned\n" 228 | 229 | # Add text to output buffer. 230 | self.output_field.buffer.set_document(Document( 231 | text = output, cursor_position = 0), True) 232 | 233 | return True 234 | 235 | def accept(buff: Buffer) -> bool: 236 | """ This method gets called when the user presses enter/return 237 | in the filter box. It is interpreted as either 'execute query' 238 | or 'fetch next page of results' if filter query hasn't changed. 
239 | """ 240 | obj = self.my_app.selected_object 241 | sql_conn = obj.conn 242 | identifier = object_to_identifier(obj) 243 | query = sql_conn.preview_query( 244 | name = identifier, 245 | obj_type = obj.otype, 246 | filter_query = buff.text, 247 | limit = self.my_app.preview_limit_rows) 248 | if query is None: 249 | return True 250 | 251 | func = partial(refresh_results, 252 | window_height = self.output_field.window.render_info.window_height) 253 | if sql_conn.query != query: 254 | # Exit the app to execute the query 255 | self.my_app.application.exit(result = ["preview", query]) 256 | self.my_app.application.pre_run_callables.append(func) 257 | else: 258 | # No need to exit let's just go and fetch 259 | func() 260 | return True # Keep filter text 261 | 262 | def cancel_handler() -> None: 263 | sql_conn = self.my_app.selected_object.conn 264 | sql_conn.close_cursor() 265 | self.input_buffer.text = "" 266 | self.output_field.buffer.set_document(Document( 267 | text = help_text, cursor_position = 0 268 | ), True) 269 | self.my_app.show_preview = False 270 | self.my_app.show_sidebar = True 271 | self.my_app.application.layout.focus(self.input_buffer) 272 | self.my_app.application.layout.focus("sidebarbuffer") 273 | return None 274 | 275 | self.input_buffer.accept_handler = accept 276 | self.cancel_button = Button(text = "Done", handler = cancel_handler) 277 | 278 | def create_completion_float(self) -> Float: 279 | return Float( 280 | xcursor = True, 281 | ycursor = True, 282 | transparent = True, 283 | attach_to_window = self.input_window, 284 | content = CompletionsMenu( 285 | scroll_offset = 1, 286 | max_height = 16, 287 | extra_filter = has_focus(self.input_buffer))) 288 | 289 | def create_container(self): 290 | 291 | container = HSplit( 292 | [ 293 | Box( 294 | body = VSplit( 295 | [self.input_window, self.cancel_button], 296 | padding=1 297 | ), 298 | padding=1, 299 | style="class:preview-input-field" 300 | ), 301 | Window(height=1, char="-", 
style="class:preview-divider-line"), 302 | self.output_field, 303 | self.search_field, 304 | ]) 305 | 306 | frame = Shadow( 307 | body = Frame( 308 | title = lambda: "Preview: " + self.my_app.selected_object.name, 309 | body = container, 310 | style="class:dialog.body", 311 | width = D(preferred = 180, min = 30), 312 | modal = True)) 313 | 314 | return ConditionalContainer( 315 | content = frame, 316 | filter = ShowPreview(self.my_app) & ~is_done) 317 | -------------------------------------------------------------------------------- /odbcli/sidebar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import platform 3 | from cyanodbc import Connection 4 | from typing import List, Optional, Callable 5 | from logging import getLogger 6 | from asyncio import get_event_loop 7 | from threading import Thread, Lock 8 | from prompt_toolkit.layout.containers import HSplit, Window, ScrollOffsets, ConditionalContainer, Container 9 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples 10 | from prompt_toolkit.formatted_text import fragment_list_width 11 | from prompt_toolkit.layout.controls import FormattedTextControl, BufferControl, UIContent 12 | from prompt_toolkit.layout.dimension import Dimension 13 | from prompt_toolkit.buffer import Buffer 14 | from prompt_toolkit.document import Document 15 | from prompt_toolkit.filters import is_done, renderer_height_is_known 16 | from prompt_toolkit.layout.margins import ScrollbarMargin 17 | from prompt_toolkit.mouse_events import MouseEvent 18 | from prompt_toolkit.lexers import Lexer 19 | from prompt_toolkit.widgets import SearchToolbar 20 | from prompt_toolkit.filters import Condition 21 | from .conn import sqlConnection 22 | from .filters import ShowSidebar 23 | from .utils import if_mousedown 24 | from .__init__ import __version__ 25 | 26 | class myDBObject: 27 | def __init__( 28 | self, 29 | my_app: "sqlApp", 30 | conn: sqlConnection, 31 | name: str, 32 | otype: 
str, 33 | level: Optional[int] = 0, 34 | children: Optional[List["myDBObject"]] = None, 35 | parent: Optional["myDBObject"] = None, 36 | next_object: Optional["myDBObject"] = None 37 | ) -> None: 38 | 39 | self.my_app = my_app 40 | self.conn = conn 41 | self.children = children 42 | self.parent = parent 43 | self.next_object = next_object 44 | # Held while modifying children, parent, next_object 45 | # As some of these operations (expand) happen asynchronously 46 | self._lock = Lock() 47 | self.name = name 48 | self.otype = otype 49 | self.level = level 50 | self.selected: bool = False 51 | 52 | def _expand_internal(self) -> None: 53 | """ 54 | Subclass hook: query this node's children and link them in. Runs on the worker thread started by expand(). 55 | """ 56 | raise NotImplementedError() 57 | 58 | def expand(self) -> None: 59 | """ 60 | Populate children on a daemon thread (no-op if already expanded), then schedule a redraw on the event loop. 61 | """ 62 | if self.children is not None: 63 | return None 64 | 65 | loop = get_event_loop() 66 | self.my_app.show_expanding_object = True 67 | self.my_app.application.invalidate() 68 | def _redraw_after_io(): 69 | """ Callback, scheduled after threaded I/O 70 | completes """ 71 | self.my_app.show_expanding_object = False 72 | self.my_app.obj_list_changed = True 73 | self.my_app.application.invalidate() 74 | 75 | def _run(): 76 | """ Executes in a thread """ 77 | self._expand_internal() # Blocking I/O 78 | loop.call_soon_threadsafe(_redraw_after_io) 79 | 80 | # Don't use 'run_in_executor': a daemon thread is ideal here. 81 | t = Thread(target = _run, daemon = True) 82 | t.start() 83 | 84 | def collapse(self) -> None: 85 | """ 86 | Collapse this node (no-op unless it is the selected object). 87 | If expanded, unlink its children from the object list and drop 88 | them (children = None); otherwise select and collapse the parent. 89 | Dropping the children forces a database re-query on the next expand, which may be suboptimal.
TODO: Codify not/refresh path 90 | """ 91 | if self is not self.my_app.selected_object: 92 | return 93 | if self.children is not None: 94 | obj = self.children[len(self.children) - 1].next_object 95 | while obj.level > self.level: 96 | obj = obj.next_object 97 | with self._lock: 98 | self.next_object = obj 99 | self.children = None 100 | elif self.parent is not None: 101 | self.my_app.selected_object = self.parent 102 | self.parent.collapse() 103 | 104 | self.my_app.obj_list_changed = True 105 | 106 | def add_children(self, list_obj: List["myDBObject"]) -> None: 107 | lst = list(filter(lambda x: x.name != "", list_obj)) 108 | if len(lst): 109 | with self._lock: 110 | self.children = lst 111 | for i in range(len(self.children) - 1): 112 | self.children[i].next_object = self.children[i + 1] 113 | self.children[len(self.children) - 1].next_object = self.next_object 114 | self.next_object = self.children[0] 115 | 116 | class myDBColumn(myDBObject): 117 | def _expand_internal(self) -> None: 118 | return None 119 | 120 | class myDBFunction(myDBObject): 121 | def _expand_internal(self) -> None: 122 | cat = "%" 123 | schema = "%" 124 | # https://docs.microsoft.com/en-us/sql/odbc/reference/syntax/sqlprocedurecolumns-function?view=sql-server-ver15 125 | # CatalogName cannot contain a string search pattern 126 | 127 | if self.parent is not None: 128 | if type(self.parent).__name__ == "myDBSchema": 129 | schema = self.conn.sanitize_search_string(self.parent.name) 130 | elif type(self.parent).__name__ == "myDBCatalog": 131 | cat = self.parent.name 132 | if self.parent.parent is not None: 133 | if type(self.parent.parent).__name__ == "myDBCatalog": 134 | cat = self.parent.parent.name 135 | 136 | res = self.conn.find_procedure_columns( 137 | catalog = cat, 138 | schema = schema, 139 | procedure = self.conn.sanitize_search_string(self.name), 140 | column = "%") 141 | 142 | lst = [myDBColumn( 143 | my_app = self.my_app, 144 | conn = self.conn, 145 | name = col.column, 146 | otype = 
col.type_name, 147 | parent = self, 148 | level = self.level + 1) for col in res] 149 | 150 | self.add_children(list_obj = lst) 151 | return None 152 | 153 | class myDBTable(myDBObject): 154 | def _expand_internal(self) -> None: 155 | cat = "%" 156 | schema = "%" 157 | # https://docs.microsoft.com/en-us/sql/odbc/reference/syntax/sqlcolumns-function?view=sql-server-ver15 158 | # CatalogName cannot contain a string search pattern 159 | 160 | if self.parent is not None: 161 | if type(self.parent).__name__ == "myDBSchema": 162 | schema = self.conn.sanitize_search_string(self.parent.name) 163 | elif type(self.parent).__name__ == "myDBCatalog": 164 | cat = self.parent.name 165 | if self.parent.parent is not None: 166 | if type(self.parent.parent).__name__ == "myDBCatalog": 167 | cat = self.parent.parent.name 168 | 169 | res = self.conn.find_columns( 170 | catalog = cat, 171 | schema = schema, 172 | table = self.name, 173 | column = "%") 174 | 175 | lst = [myDBColumn( 176 | my_app = self.my_app, 177 | conn = self.conn, 178 | name = col.column, 179 | otype = col.type_name, 180 | parent = self, 181 | level = self.level + 1) for col in res] 182 | 183 | self.add_children(list_obj = lst) 184 | return None 185 | 186 | class myDBSchema(myDBObject): 187 | def _expand_internal(self) -> None: 188 | 189 | cat = self.conn.sanitize_search_string(self.parent.name) if self.parent is not None else "%" 190 | res = self.conn.find_tables( 191 | catalog = cat, 192 | schema = self.conn.sanitize_search_string(self.name), 193 | table = "", 194 | type = "") 195 | resf = self.conn.find_procedures( 196 | catalog = cat, 197 | schema = self.conn.sanitize_search_string(self.name), 198 | procedure = "") 199 | tables = [] 200 | views = [] 201 | functions = [] 202 | lst = [] 203 | for table in res: 204 | if table.type.lower() == 'table': 205 | tables.append(table.name) 206 | if table.type.lower() == 'view': 207 | views.append(table.name) 208 | lst.append(myDBTable( 209 | my_app = self.my_app, 210 | conn 
= self.conn, 211 | name = table.name, 212 | otype = table.type.lower(), 213 | parent = self, 214 | level = self.level + 1)) 215 | for func in resf: 216 | functions.append(func.name) 217 | lst.append(myDBFunction( 218 | my_app = self.my_app, 219 | conn = self.conn, 220 | name = func.name, 221 | otype = "function", 222 | parent = self, 223 | level = self.level + 1)) 224 | 225 | self.conn.dbmetadata.extend_objects( 226 | catalog = self.conn.escape_name(self.parent.name) if self.parent else "", 227 | schema = self.conn.escape_name(self.name), 228 | names = self.conn.escape_names(tables), 229 | obj_type = "table") 230 | self.conn.dbmetadata.extend_objects( 231 | catalog = self.conn.escape_name(self.parent.name) if self.parent else "", 232 | schema = self.conn.escape_name(self.name), 233 | names = self.conn.escape_names(views), 234 | obj_type = "view") 235 | self.conn.dbmetadata.extend_objects( 236 | catalog = self.conn.escape_name(self.parent.name) if self.parent else "", 237 | schema = self.conn.escape_name(self.name), 238 | names = self.conn.escape_names(functions), 239 | obj_type = "function") 240 | self.add_children(list_obj = lst) 241 | return None 242 | 243 | class myDBCatalog(myDBObject): 244 | def _expand_internal(self) -> None: 245 | schemas = lst = [] 246 | schemas = self.conn.list_schemas( 247 | catalog = self.conn.sanitize_search_string(self.name)) 248 | 249 | if len(schemas) < 1 or all([s == "" for s in schemas]): 250 | res = self.conn.find_tables( 251 | catalog = self.conn.sanitize_search_string(self.name), 252 | schema = "", 253 | table = "", 254 | type = "") 255 | schemas = [r.schema for r in res] 256 | 257 | self.conn.dbmetadata.extend_schemas( 258 | catalog = self.conn.escape_name(self.name), 259 | names = self.conn.escape_names(schemas)) 260 | 261 | if not all([s == "" for s in schemas]): 262 | # Schemas were found either having called list_schemas 263 | # or via the find_tables call 264 | lst = [myDBSchema( 265 | my_app = self.my_app, 266 | conn = 
self.conn, 267 | name = schema, 268 | otype = "schema", 269 | parent = self, 270 | level = self.level + 1) for schema in sorted(set(schemas))] 271 | elif len(schemas): 272 | # No schemas found; but if there are tables then these are direct 273 | # descendents, i.e. MySQL 274 | tables = [] 275 | views = [] 276 | lst = [] 277 | for table in res: 278 | if table.type.lower() == 'table': 279 | tables.append(table.name) 280 | if table.type.lower() == 'view': 281 | views.append(table.name) 282 | lst.append(myDBTable( 283 | my_app = self.my_app, 284 | conn = self.conn, 285 | name = table.name, 286 | otype = table.type.lower(), 287 | parent = self, 288 | level = self.level + 1)) 289 | self.conn.dbmetadata.extend_objects( 290 | catalog = self.conn.escape_name(self.name), 291 | schema = "", names = self.conn.escape_names(tables), 292 | obj_type = "table") 293 | self.conn.dbmetadata.extend_objects( 294 | catalog = self.conn.escape_name(self.name), 295 | schema = "", names = self.conn.escape_names(views), 296 | obj_type = "view") 297 | 298 | self.add_children(list_obj = lst) 299 | return None 300 | 301 | 302 | class myDBConn(myDBObject): 303 | def _expand_internal(self) -> None: 304 | if not self.conn.connected(): 305 | return None 306 | 307 | lst = [] 308 | cat_support = self.conn.catalog_support() 309 | if cat_support: 310 | rows = self.conn.list_catalogs() 311 | if len(rows): 312 | lst = [myDBCatalog( 313 | my_app = self.my_app, 314 | conn = self.conn, 315 | name = row, 316 | otype = "catalog", 317 | parent = self, 318 | level = self.level + 1) for row in rows] 319 | self.conn.dbmetadata.extend_catalogs( 320 | self.conn.escape_names(rows)) 321 | else: 322 | res = self.conn.find_tables( 323 | catalog = "%", 324 | schema = "", 325 | table = "", 326 | type = "") 327 | schemas = [r.schema for r in res] 328 | self.conn.dbmetadata.extend_schemas(catalog = "", 329 | names = self.conn.escape_names(schemas)) 330 | if not all([s == "" for s in schemas]): 331 | lst = [myDBSchema( 332 | 
my_app = self.my_app, 333 | conn = self.conn, 334 | name = schema, 335 | otype = "schema", 336 | parent = self, 337 | level = self.level + 1) for schema in sorted(set(schemas))] 338 | elif len(schemas): 339 | tables = [] 340 | views = [] 341 | lst = [] 342 | for table in res: 343 | if table.type.lower() == 'table': 344 | tables.append(table.name) 345 | if table.type.lower() == 'view': 346 | views.append(table.name) 347 | lst.append(myDBTable( 348 | my_app = self.my_app, 349 | conn = self.conn, 350 | name = table.name, 351 | otype = table.type.lower(), 352 | parent = self, 353 | level = self.level + 1)) 354 | self.conn.dbmetadata.extend_objects(catalog = "", 355 | schema = "", names = self.conn.escape_names(tables), 356 | obj_type = "table") 357 | self.conn.dbmetadata.extend_objects(catalog = "", 358 | schema = "", names = self.conn.escape_names(views), 359 | obj_type = "view") 360 | self.add_children(list_obj = lst) 361 | return None 362 | 363 | def sql_sidebar_help(my_app: "sqlApp"): 364 | """ 365 | Create the `Layout` for the help text for the current item in the sidebar. 366 | """ 367 | token = "class:sidebar.helptext" 368 | 369 | def get_current_description(): 370 | """ 371 | Return the description of the selected option. 372 | """ 373 | obj = my_app.selected_object 374 | if obj is not None: 375 | return obj.name 376 | return "" 377 | 378 | def get_help_text(): 379 | return [(token, get_current_description())] 380 | 381 | return ConditionalContainer( 382 | content=Window( 383 | FormattedTextControl(get_help_text), style=token, height=Dimension(min=3) 384 | ), 385 | filter = ~is_done 386 | & ShowSidebar(my_app) 387 | & Condition( 388 | lambda: not my_app.show_exit_confirmation 389 | )) 390 | 391 | def expanding_object_notification(my_app: "sqlApp"): 392 | """ 393 | Create the `Layout` for the 'Expanding object' notification. 394 | """ 395 | 396 | def get_text_fragments(): 397 | # Show navigation info. 
398 | return [("fg:red", "Expanding object ...")] 399 | 400 | return ConditionalContainer( 401 | content = Window( 402 | FormattedTextControl(get_text_fragments), 403 | style = "class:sidebar", 404 | width=Dimension.exact( 45 ), 405 | height=Dimension(max = 1), 406 | ), 407 | filter = ~is_done 408 | & ShowSidebar(my_app) 409 | & Condition( 410 | lambda: my_app.show_expanding_object 411 | )) 412 | 413 | def sql_sidebar_navigation(): 414 | """ 415 | Create the `Layout` showing the navigation information for the sidebar. 416 | """ 417 | 418 | def get_text_fragments(): 419 | # Show navigation info. 420 | return [ 421 | ("class:sidebar.navigation", " "), 422 | ("class:sidebar.navigation.key", "[Up/Dn]"), 423 | ("class:sidebar.navigation", " "), 424 | ("class:sidebar.navigation.description", "Navigate"), 425 | ("class:sidebar.navigation", " "), 426 | ("class:sidebar.navigation.key", "[L/R]"), 427 | ("class:sidebar.navigation", " "), 428 | ("class:sidebar.navigation.description", "Expand/Collapse"), 429 | ("class:sidebar.navigation", "\n "), 430 | ("class:sidebar.navigation.key", "[Enter]"), 431 | ("class:sidebar.navigation", " "), 432 | ("class:sidebar.navigation.description", "Connect/Preview"), 433 | ] 434 | 435 | return Window( 436 | FormattedTextControl(get_text_fragments), 437 | style = "class:sidebar.navigation", 438 | width=Dimension.exact( 45 ), 439 | height=Dimension(max = 2), 440 | ) 441 | 442 | def show_sidebar_button_info(my_app: "sqlApp") -> Container: 443 | """ 444 | Create `Layout` for the information in the right-bottom corner. 445 | (The right part of the status bar.) 446 | """ 447 | 448 | @if_mousedown 449 | def toggle_sidebar(mouse_event: MouseEvent) -> None: 450 | " Click handler for the menu. 
" 451 | my_app.show_sidebar = not my_app.show_sidebar 452 | 453 | # TO DO: app version rather than python 454 | version = sys.version_info 455 | tokens: StyleAndTextTuples = [ 456 | ("class:status-toolbar.key", "[C-t]", toggle_sidebar), 457 | ("class:status-toolbar", " Object Browser", toggle_sidebar), 458 | ("class:status-toolbar", " - "), 459 | ("class:status-toolbar.cli-version", "odbcli %s" % __version__), 460 | ("class:status-toolbar", " "), 461 | ] 462 | width = fragment_list_width(tokens) 463 | 464 | def get_text_fragments() -> StyleAndTextTuples: 465 | # Python version 466 | return tokens 467 | 468 | return ConditionalContainer( 469 | content=Window( 470 | FormattedTextControl(get_text_fragments), 471 | style="class:status-toolbar", 472 | height=Dimension.exact(1), 473 | width=Dimension.exact(width), 474 | ), 475 | filter=~is_done 476 | & Condition( 477 | lambda: not my_app.show_exit_confirmation 478 | ) 479 | & renderer_height_is_known 480 | ) 481 | 482 | def sql_sidebar(my_app: "sqlApp") -> Window: 483 | """ 484 | Create the `Layout` for the sidebar with the configurable objects. 485 | """ 486 | 487 | @if_mousedown 488 | def expand_item(obj: "myDBObject") -> None: 489 | obj.expand() 490 | 491 | def tokenize_obj(obj: "myDBObject") -> StyleAndTextTuples: 492 | " Recursively build the token list " 493 | tokens: StyleAndTextTuples = [] 494 | selected = obj is my_app.selected_object 495 | expanded = obj.children is not None 496 | connected = obj.otype == "Connection" and obj.conn.connected() 497 | active = my_app.active_conn is not None and my_app.active_conn is obj.conn and obj.level == 0 498 | 499 | act = ",active" if active else "" 500 | sel = ",selected" if selected else "" 501 | if len(obj.name) > 24 - 2 * obj.level: 502 | name_trim = obj.name[:24 - 2 * obj.level - 3] + "..." 
503 | else: 504 | name_trim = ("%-" + str(24 - 2 * obj.level) + "s") % obj.name 505 | 506 | tokens.append(("class:sidebar.label" + sel + act, " >" if connected else " ")) 507 | tokens.append( 508 | ("class:sidebar.label" + sel, " " * 2 * obj.level, expand_item) 509 | ) 510 | tokens.append( 511 | ("class:sidebar.label" + sel + act, 512 | name_trim, 513 | expand_item) 514 | ) 515 | tokens.append(("class:sidebar.status" + sel + act, " ", expand_item)) 516 | tokens.append(("class:sidebar.status" + sel + act, "%+12s" % obj.otype, expand_item)) 517 | 518 | if selected: 519 | tokens.append(("[SetCursorPosition]", "")) 520 | 521 | if expanded: 522 | tokens.append(("class:sidebar.status" + sel + act, "\/")) 523 | else: 524 | tokens.append(("class:sidebar.status" + sel + act, " <" if selected else " ")) 525 | 526 | # Expand past the edge of the visible buffer to get an even panel 527 | tokens.append(("class:sidebar.status" + sel + act, " " * 10)) 528 | return tokens 529 | 530 | search_buffer = Buffer(name = "sidebarsearchbuffer") 531 | search_field = SearchToolbar( 532 | search_buffer = search_buffer, 533 | ignore_case = True 534 | ) 535 | def _buffer_pos_changed(buff): 536 | """ This callback gets executed after cursor position changes. Most 537 | of the time we register a key-press (up / down), we change the 538 | selected object and as a result of that the cursor changes. By that 539 | time we don't need to updat the selected object (cursor changed as 540 | a result of the selected object being updated). The one exception 541 | is when searching the sidebar buffer. When this happens the cursor 542 | moves ahead of the selected object. When that happens, here we 543 | update the selected object to follow suit. 
544 | """ 545 | if buff.document.cursor_position_row != my_app.selected_object_idx[0]: 546 | my_app.select(buff.document.cursor_position_row) 547 | 548 | sidebar_buffer = Buffer( 549 | name = "sidebarbuffer", 550 | read_only = True, 551 | on_cursor_position_changed = _buffer_pos_changed 552 | ) 553 | 554 | class myLexer(Lexer): 555 | def __init__(self, *args, **kwargs): 556 | super().__init__(*args, **kwargs) 557 | self._obj_list = [] 558 | 559 | def add_objects(self, objects: List): 560 | self._obj_list = objects 561 | 562 | def lex_document(self, document: Document) -> Callable[[int], StyleAndTextTuples]: 563 | def get_line(lineno: int) -> StyleAndTextTuples: 564 | # TODO: raise out-of-range exception 565 | return tokenize_obj(self._obj_list[lineno]) 566 | return get_line 567 | 568 | 569 | sidebar_lexer = myLexer() 570 | 571 | class myControl(BufferControl): 572 | 573 | def move_cursor_down(self): 574 | my_app.select_next() 575 | # Need to figure out what do do here 576 | # AFAICT these are only called for the mouse handler 577 | # when events are otherwise not handled 578 | def move_cursor_up(self): 579 | my_app.select_previous() 580 | 581 | def mouse_handler(self, mouse_event: MouseEvent) -> "NotImplementedOrNone": 582 | """ 583 | There is an intricate relationship between the cursor position 584 | in the sidebar document and which object is market as 'selected' 585 | in the linked list. Let's not muck that up by allowing the user 586 | to change the cursor position in the sidebar document with the mouse. 
587 | """ 588 | return NotImplemented 589 | 590 | def create_content(self, width: int, height: Optional[int]) -> UIContent: 591 | # Only traverse the obj_list if it has been expanded / collapsed 592 | if not my_app.obj_list_changed: 593 | self.buffer.cursor_position = my_app.selected_object_idx[1] 594 | return super().create_content(width, height) 595 | 596 | res = [] 597 | obj = my_app.obj_list[0] 598 | res.append(obj) 599 | while obj.next_object is not my_app.obj_list[0]: 600 | res.append(obj.next_object) 601 | obj = obj.next_object 602 | 603 | self.lexer.add_objects(res) 604 | self.buffer.set_document(Document( 605 | text = "\n".join([a.name for a in res]), cursor_position = my_app.selected_object_idx[1]), True) 606 | # Reset obj_list_changed flag, now that we have had a chance to 607 | # regenerate the sidebar document content 608 | my_app.obj_list_changed = False 609 | return super().create_content(width, height) 610 | 611 | 612 | 613 | sidebar_control = myControl( 614 | buffer = sidebar_buffer, 615 | lexer = sidebar_lexer, 616 | search_buffer_control = search_field.control, 617 | focusable = True, 618 | ) 619 | 620 | return HSplit([ 621 | search_field, 622 | Window( 623 | sidebar_control, 624 | right_margins = [ScrollbarMargin(display_arrows = True)], 625 | style = "class:sidebar", 626 | width = Dimension.exact( 45 ), 627 | height = Dimension(min = 7, preferred = 33), 628 | scroll_offsets = ScrollOffsets(top = 1, bottom = 1)), 629 | Window( 630 | height = Dimension.exact(1), 631 | char = "\u2500", 632 | style = "class:sidebar,separator", 633 | ), 634 | expanding_object_notification(my_app), 635 | sql_sidebar_navigation()]) 636 | -------------------------------------------------------------------------------- /odbcli/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | For internal use only. 
3 | """ 4 | import re 5 | from typing import Callable, TypeVar, cast 6 | import os 7 | from platform import system 8 | 9 | from prompt_toolkit.mouse_events import MouseEvent, MouseEventType 10 | 11 | __all__ = [ 12 | "has_unclosed_brackets", 13 | "get_jedi_script_from_document", 14 | "document_is_multiline_python", 15 | ] 16 | 17 | def has_unclosed_brackets(text: str) -> bool: 18 | """ 19 | Scan the string from the end. If we find an opening bracket 20 | for which we haven't seen a closing one yet, return True. 21 | """ 22 | stack = [] 23 | 24 | # Ignore braces inside strings 25 | text = re.sub(r"""('[^']*'|"[^"]*")""", "", text) # XXX: escaped quotes are not handled. 26 | 27 | for c in reversed(text): 28 | if c in "])}": 29 | stack.append(c) 30 | 31 | elif c in "[({": 32 | if stack: 33 | if ( 34 | (c == "[" and stack[-1] == "]") 35 | or (c == "{" and stack[-1] == "}") 36 | or (c == "(" and stack[-1] == ")") 37 | ): 38 | stack.pop() 39 | else: 40 | # Opening bracket for which we haven't seen a closing one. 41 | return True 42 | 43 | return False 44 | 45 | 46 | def get_jedi_script_from_document(document, locals, globals): 47 | import jedi # We keep this import in-line, to improve start-up time. 48 | 49 | # Importing Jedi is 'slow'. 50 | 51 | try: 52 | return jedi.Interpreter( 53 | document.text, 54 | column=document.cursor_position_col, 55 | line=document.cursor_position_row + 1, 56 | path="input-text", 57 | namespaces=[locals, globals], 58 | ) 59 | except ValueError: 60 | # Invalid cursor position.
61 | # ValueError('`column` parameter is not in a valid range.') 62 | return None 63 | except AttributeError: 64 | # Workaround for #65: https://github.com/jonathanslenders/python-prompt-toolkit/issues/65 65 | # See also: https://github.com/davidhalter/jedi/issues/508 66 | return None 67 | except IndexError: 68 | # Workaround Jedi issue #514: for https://github.com/davidhalter/jedi/issues/514 69 | return None 70 | except KeyError: 71 | # Workaround for a crash when the input is "u'", the start of a unicode string. 72 | return None 73 | except Exception: 74 | # Workaround for: https://github.com/jonathanslenders/ptpython/issues/91 75 | return None 76 | 77 | 78 | _multiline_string_delims = re.compile("""[']{3}|["]{3}""") 79 | 80 | 81 | def document_is_multiline_python(document): 82 | """ 83 | Determine whether this is a multiline Python document. 84 | """ 85 | 86 | def ends_in_multiline_string() -> bool: 87 | """ 88 | ``True`` if we're inside a multiline string at the end of the text. 89 | """ 90 | delims = _multiline_string_delims.findall(document.text) 91 | opening = None 92 | for delim in delims: 93 | if opening is None: 94 | opening = delim 95 | elif delim == opening: 96 | opening = None 97 | return bool(opening) 98 | 99 | if "\n" in document.text or ends_in_multiline_string(): 100 | return True 101 | 102 | def line_ends_with_colon() -> bool: 103 | return document.current_line.rstrip()[-1:] == ":" 104 | 105 | # If we just typed a colon, or still have open brackets, always insert a real newline. 106 | if ( 107 | line_ends_with_colon() 108 | or ( 109 | document.is_cursor_at_the_end 110 | and has_unclosed_brackets(document.text_before_cursor) 111 | ) 112 | or document.text.startswith("@") 113 | ): 114 | return True 115 | 116 | # If the character before the cursor is a backslash (line continuation 117 | # char), insert a new line.
118 | elif document.text_before_cursor[-1:] == "\\": 119 | return True 120 | 121 | return False 122 | 123 | 124 | _T = TypeVar("_T", bound=Callable[[MouseEvent], None]) 125 | 126 | 127 | def if_mousedown(handler: _T) -> _T: 128 | """ 129 | Decorator for mouse handlers. 130 | Only handle event when the user pressed mouse down. 131 | 132 | (When applied to a token list. Scroll events will bubble up and are handled 133 | by the Window.) 134 | """ 135 | 136 | def handle_if_mouse_down(mouse_event: MouseEvent): 137 | if mouse_event.event_type == MouseEventType.MOUSE_DOWN: 138 | return handler(mouse_event) 139 | else: 140 | return NotImplemented 141 | 142 | return cast(_T, handle_if_mouse_down) 143 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import os 3 | import re 4 | 5 | with open("README.md", "r") as fh: 6 | long_description = fh.read() 7 | 8 | def get_version(package): 9 | """ 10 | Return package version as listed in `__version__` in `__init__.py`. 
11 | """ 12 | path = os.path.join(os.path.dirname(__file__), package, "__init__.py") 13 | with open(path, "rb") as f: 14 | init_py = f.read().decode("utf-8") 15 | return re.search("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1) 16 | 17 | 18 | install_requirements = [ 19 | 'cyanodbc >= 0.0.3', 20 | 'prompt_toolkit >= 3.0.5', 21 | 'Pygments>=2.6.1', 22 | 'sqlparse >= 0.3.1', 23 | 'configobj >= 5.0.6', 24 | 'click >= 7.1.2', 25 | 'cli_helpers >= 2.0.1' 26 | ] 27 | 28 | setuptools.setup( 29 | name = "odbcli", 30 | version = get_version("odbcli"), 31 | author = "Oliver Gjoneski", 32 | author_email = "ogjoneski@gmail.com", 33 | description = "ODBC Client", 34 | license = 'BSD-3', 35 | long_description = long_description, 36 | long_description_content_type = "text/markdown", 37 | install_requires = install_requirements, 38 | url = "https://github.com/pypa/odbc-cli", 39 | scripts=[ 40 | 'odbc-cli' 41 | ], 42 | packages = setuptools.find_packages(), 43 | include_package_data = True, 44 | classifiers=[ 45 | "Programming Language :: Python :: 3", 46 | "License :: OSI Approved :: BSD License", 47 | "Operating System :: OS Independent", 48 | ], 49 | # As python prompt toolkit 50 | python_requires = '>=3.6.1', 51 | ) 52 | -------------------------------------------------------------------------------- /tests/test_dbmetadata.py: -------------------------------------------------------------------------------- 1 | from odbcli.dbmetadata import DbMetadata 2 | import pytest 3 | 4 | def test_catalogs(): 5 | db = DbMetadata() 6 | cats = ["a", "b", "c", ""] 7 | res = db.get_catalogs() 8 | assert res is None 9 | db.extend_catalogs(cats) 10 | res = db.get_catalogs(obj_type = "table") 11 | assert res == cats 12 | res = db.get_catalogs(obj_type = "view") 13 | assert res == cats 14 | res = db.get_catalogs(obj_type = "function") 15 | assert res == cats 16 | 17 | def test_schemas(): 18 | db = DbMetadata() 19 | cats = ["a", "b", "c", ""] 20 | schemas = ["A", "B", "C", ""] 21 | 
db.extend_catalogs(cats) 22 | res = db.get_schemas(catalog = "d") 23 | assert res is None 24 | res = db.get_schemas(catalog = "a") 25 | assert res == [] 26 | db.extend_schemas(catalog = "d", names = schemas) 27 | db.extend_schemas(catalog = "", names = schemas) 28 | res = db.get_catalogs(obj_type = "table") 29 | assert "d" in res 30 | res = db.get_schemas(catalog = "d") 31 | assert res == schemas 32 | res = db.get_schemas(catalog = "") 33 | assert res == schemas 34 | 35 | def test_objects(): 36 | db = DbMetadata() 37 | cats = ["a", "b", "c", ""] 38 | schemas = ["A", "B", "C", ""] 39 | tables = ["t1", "t2", "t3"] 40 | views = ["v1", "v2", "v3"] 41 | db.extend_catalogs(cats) 42 | db.extend_schemas(catalog = "a", names = schemas) 43 | res = db.get_objects(catalog = "a", schema = "D") 44 | assert res is None 45 | res = db.get_objects(catalog = "a", schema = "A") 46 | assert res == [] 47 | db.extend_objects(catalog = "a", schema = "A", names = tables, obj_type = "table") 48 | db.extend_objects(catalog = "a", schema = "D", names = tables, obj_type = "table") 49 | res = db.get_objects(catalog = "a", schema = "A") 50 | assert res == tables 51 | res = db.get_objects(catalog = "a", schema = "D") 52 | assert res == tables 53 | res = db.get_objects(catalog = "a", schema = "D", obj_type = "view") 54 | assert res == [] 55 | db.extend_objects(catalog = "a", schema = "A", names = views, obj_type = "view") 56 | res = db.get_objects(catalog = "a", schema = "A", obj_type = "view") 57 | assert res == views 58 | res = db.get_objects(catalog = "a", schema = "A", obj_type = "table") 59 | assert res == tables 60 | -------------------------------------------------------------------------------- /tests/test_parseutils.py: -------------------------------------------------------------------------------- 1 | from odbcli.completion.parseutils.tables import extract_table_identifiers, TableReference 2 | from sqlparse import parse 3 | import pytest 4 | 5 | def test_get_table_identifiers(): 6 | qry 
= "SELECT a.col1, b.col2 " \ 7 | "FROM abc.def.ghi AS a " \ 8 | "INNER JOIN jkl.mno.pqr AS b ON a.id_one = b.id_two" 9 | parsed = parse(qry)[0] 10 | res = list(extract_table_identifiers(parsed)) 11 | expected = [ 12 | TableReference(None, "a", "col1", None, False), 13 | TableReference(None, "b", "col2", None, False), 14 | TableReference("abc", "def", "ghi", "a", False), 15 | TableReference("jkl", "mno", "pqr", "b", False), 16 | ] 17 | 18 | assert res == expected 19 | -------------------------------------------------------------------------------- /tests/test_sqlcompletion.py: -------------------------------------------------------------------------------- 1 | from odbcli.completion.sqlcompletion import SqlStatement 2 | import pytest 3 | 4 | 5 | @pytest.mark.parametrize( 6 | "before_cursor, expected", 7 | [ 8 | (" ", (None, None)), 9 | ("abc", (None, None)), 10 | ("abc.", (None, "abc")), 11 | ("abc.def", (None, "abc")), 12 | ("abc.def.", ("abc", "def")), 13 | ("abc.def.ghi", ("abc", "def")) 14 | ], 15 | ) 16 | def test_get_identifier_parents(before_cursor, expected): 17 | stmt = SqlStatement( 18 | full_text = "SELECT * FROM abc.def.ghi", 19 | text_before_cursor = before_cursor) 20 | 21 | assert stmt.get_identifier_parents() == expected 22 | --------------------------------------------------------------------------------