├── .github
└── workflows
│ └── publish.yml
├── .gitignore
├── .travis.yml
├── LICENSE
├── MANIFEST.in
├── README.md
├── ci
└── travis
│ └── deploy.sh
├── odbc-cli
├── odbcli
├── __init__.py
├── __main__.py
├── app.py
├── cli.py
├── completion
│ ├── __init__.py
│ ├── mssqlcompleter.py
│ ├── mssqlliterals
│ │ ├── __init__.py
│ │ ├── main.py
│ │ └── sqlliterals.json
│ ├── parseutils
│ │ ├── __init__.py
│ │ ├── ctes.py
│ │ ├── meta.py
│ │ ├── tables.py
│ │ └── utils.py
│ ├── prioritization.py
│ └── sqlcompletion.py
├── config.py
├── conn.py
├── dbmetadata.py
├── disconnect_dialog.py
├── filters.py
├── layout.py
├── loginprompt.py
├── odbclirc
├── odbcstyle.py
├── preview.py
├── sidebar.py
└── utils.py
├── setup.py
└── tests
├── test_dbmetadata.py
├── test_parseutils.py
└── test_sqlcompletion.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: odbcli
2 | on: push
3 |
4 | jobs:
5 | build-n-publish:
6 | runs-on: ubuntu-latest
7 |
8 | steps:
9 | - uses: actions/checkout@master
10 | - name: Set up Python 3.7
11 | uses: actions/setup-python@v3
12 | with:
13 | python-version: "3.7"
14 |
15 | - name: Update dev version
16 | if: github.event_name != 'push' || startsWith(github.ref, 'refs/tags') != true
17 | run: |
18 | sed -i "/__version__ = / s/\"$/.20${{github.run_number}}\"/" odbcli/__init__.py
19 |
20 | - name: Build and test
21 | run: |
22 | python -m pip install --upgrade --upgrade-strategy eager . twine pytest
23 | pytest .
24 | python setup.py sdist -d dist/
25 |
26 | - name: Publish to Test PyPI
27 | if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags')
28 | uses: pypa/gh-action-pypi-publish@release/v1
29 | with:
30 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
31 | repository-url: https://test.pypi.org/legacy/
32 |
33 | - name: Publish to PyPI
34 | if: startsWith(github.ref, 'refs/tags')
35 | uses: pypa/gh-action-pypi-publish@release/v1
36 | with:
37 | password: ${{ secrets.PYPI_API_TOKEN }}
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: false
2 | cache: pip
3 | language: python
4 |
5 | env:
6 | global:
7 | - secure: YGpGnxf9XMnwFiOyXKvl8XiDodhPIN7+PzOytJ0bodNvhriaua8Y54NT04WXLzpcg59t0Bp0Yd6gUERuyZ4cslBBpHtfl74iUOJPOEk/V7H0562C7TAxXxI9+XltYBhnv55Bxq/4B1Rj0YQxJ0Dl3CMfJGovAPs6zLUA7p2fz+kkhIDcc4DC2l7DE3rYhUsRCYakjaCpGSfPT1LSOBheGPWN3taPnANr0M9pTDI+Ej+kpfip7Sxna+phKZjvy9gsdhr4Wbca52826wj29E3B2Z4qw1wBVTqCzT6oPTvLt2++etbAfDkGohXy0zBOHDy2nN9oILIci5GMYd6qivQTu2gQBM7n0woMjk3NNe6buu23t7DpI9EuVpEl/3ocPdqDudaR9kW5UyWYJd0lvEJeTUolKjDGnUiDEh/tDAvgXtWiqi8DTZJTnwDwFpGB8eKNPqJhbUL9lj7HtfyElwTzJyHViwJRvAtvbMCtLroEVZau41+cTS2aIVkJLUP7ECpIKXWNGzmhfTBWsuF2aKnEH3onG84/QbuS8ADFD5XAKUP9fyq7qK7sf2loDrhSrGG2f8laIrL53ZiUgQItmcHBqW64eXILJwHN4loucc+X/gS27mNPZQZGxdF1UmWTUUnzVCTmrrL8hQqVi4Kue7j2TLtX6GCaocZ+VTS2qVWO3RI=
8 | - secure: DAQBtOrsIbNav3Ac8FLKjOQNfnlPMX+cXRk4SfjvfxJciA8W16Aa2ZyYhC0AZiEBQbMwslYS2WyYNU3+NsQfoCJ+KoWYbnfT/qfvn+K+f0uKZTazZaOHSJA4tOs6kA88L1X5yHOwMalQqyRUEkQ1lHxcGxmjMFDi+FdsCLDVQIuLaAn+Apk3TNGXKzV7IaYNkHO5kQmTrl4/v8epJcD35lo5Yw8SkUGgo0rPosecXvZU3ZoR9CpFWfBid+NECYrv8Cp7syTT9srp0y95sE38wCwr9nttNmcCKe3fLGygh0RBLyrkNl6Ajpk7qJVaa2/McIO3+4qekXlCt7dX4UpEWUH3wDjHGb//JdsaLgIMTLWP393hm4g1SJQgnPCjVi/4hxfcGfF08dnQw4JYAHC4z1qjae5A6zoI30JHOqwH4Yu3Rj9GhFhsWNA7I2kNXg/el+df23jA7mxWfXwuY5YLck156NXij8vhLYGxltgDEtmp6U1bnppQpaJKxn/9dcXZCsbJb+x8FZ0OF3lkq82vnnXAZuQ41gTosuSg1w9DQKF6OImAG0KVXDwd5qTwp8thXdrVSRXKcUltvIT9+W7kyMEKHnusLJruOvtAuCNLArSE5keJdqtb+POLfOp9ZrUObs0j/77zfZH85VmjWspWkn9rZ1v10gYHZv6Lu8NHbCM=
9 | - secure: bJrbXXuC9O/2YazHi9D95rWiPPDyDlVUE1i9WSos3lMAbqapn2vKL2gRosrM1NcIVIhnAzwi7O8ngQRiKID7nWF1xYGhapZcIya3q6z40m0HNQus+XngzSCTNaWVCMqEASDGluHOLlc/Vnbip/3tcgUSlGHk6P1p/zzoUbMKOa0igcnCG0kxR6BkcrPL1qIWdCMTDkGlmT7IFHF4ilFkhGJWssjR4mFFZX2awJJtQ2BmudsUKeytNjUI3/yQyV3eoQo2EXQDcU34ObkZFZkymdPBqfy3sSw3To2JLc4aiq/iZzIfItUaUcWS7T9HbhUD1aDG4pWWhqCBdPDbwcc0XklgX5eL9ZXI9s++gKMO7V5nu4GSE39ShYruyq8mPiRVGO4oCUFPP1N8ZMFaXLEMg5Wa4s+VxuHSMB70l5VeL7UvorJ+Ij4djBtn2RF6IYymvpDH3q5n4W2BciVg7LRZ1GX8TtSZNae4O09qYQhMGQNIKCNykbvx3DJamcI1q54kJTCjub9RVcvzUxap0PgTxjHrGB7M90wg9bHmvTaJvl7fkFC8yrtOQ/vjwttiq5X7SCrEHaBSV9TEbLOB3jB2bIfLQ92T2USjbExetOartz/P+X1dreECHXhk9g55qAvob4k3uCE3K+WAEdbVw39mqhfUPz0iuKqTsrJJ5fNyX2Y=
10 | - secure: d6W0/c04ea+0Lb5FXzzVcV4yVSqEMwrB1DOEpRapTP50a6JumziLKhIyDyzsJMgqVmkGQsAgfDNtGOZ1k/UiA3rjSXlQ5YgIwQpr0xgntKDq1/Ha02xxBpXfaS+ts3cnygldOGKF7jtPOJ+sb3hSdVmX1+q25OoyW8VtUT3GVtK0vERKxePdNzwOmrUXihYsvMMDrebmvcQvXHyk0UhtKnPlYVfyzcqoRBhlKUkUfU1LhLDQvewWcwwg7fC+7eK3Njx7EyIZ2sxU2IzDsE933YksVmJKowZOZtJsfOh6e1dspmP5Id9bGl4wTH68IzoFXzEi+/RAo68wXRjL39G6/WaRqph7YQHtLfqak1zLM7Gh0jxYSseWyYAW29RbLV71H4isXG8X6lOQE8hq9VxLaXoWx9HHaCJWuQUZVGlcnlvOcOxlr+FKmG29Z3zm9RtsZuGuZXm3+Ndf32ZBebGr5RQg5R4kkb2R/BIcm/+FBYRdCO6zjx894VXDUdUDuuwSPj0V9EvhUEEcKxwjxezIvcyE6/GC+FirhESSU/b5ufWEIW2O3TAMJxthUN5QX9JUV4+4rNhW70Y968u/BbZBCIC6fYl4UXFNl3f+I09hs6EhMVWK88ytszn+L0dACnIpedKVeJQ46GNkiX5pwve/oLtaww1cHGWP+WJNNGQmACc=
11 |
12 | matrix:
13 | include:
14 | - python: 3.7
15 | dist: xenial
16 | sudo: required
17 |
18 | script:
19 | - cd $TRAVIS_BUILD_DIR
20 | - |
21 | if [[ "$TRAVIS_TAG" == "" ]]; then
22 | sed -i "/__version__ = / s/\"$/.20$TRAVIS_BUILD_NUMBER\"/" odbcli/__init__.py
23 | fi
24 | - pip install --upgrade --upgrade-strategy eager . twine pytest
25 | - pytest .
26 | - python setup.py sdist
27 |
28 | deploy:
29 | - provider: script
30 | skip_cleanup: true
31 | script: $TRAVIS_BUILD_DIR/ci/travis/deploy.sh
32 | on:
33 | branch: master
34 | - provider: script
35 | skip_cleanup: true
36 | script: $TRAVIS_BUILD_DIR/ci/travis/deploy.sh
37 | on:
38 | tags: true
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | All rights reserved.
2 |
3 | Redistribution and use in source and binary forms, with or without modification,
4 | are permitted provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this
7 | list of conditions and the following disclaimer.
8 |
9 | * Redistributions in binary form must reproduce the above copyright notice, this
10 | list of conditions and the following disclaimer in the documentation and/or
11 | other materials provided with the distribution.
12 |
13 | * Neither the name of the {organization} nor the names of its
14 | contributors may be used to endorse or promote products derived from
15 | this software without specific prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include odbcli/completion/mssqlliterals/sqlliterals.json
2 | include odbcli/odbclirc
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # odbc-cli
2 |
3 | *Please note: this package should be considered "alpha" - while you are more than welcome to use it, you should expect that getting it to work for you will require quite a bit of self-help on your part. At the same time, it may be a great opportunity for those that want to contribute.*
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | [**odbc-cli**](https://github.com/detule/odbc-cli) is an interactive command line query tool intended to work for DataBase Management Systems (DBMS) supported by ODBC drivers.
13 |
14 | As is the case with the [remaining clients](https://github.com/dbcli/) derived from the [python prompt toolkit library](https://github.com/prompt-toolkit/python-prompt-toolkit), **odbc-cli** also supports a rich interactive command line experience, with features such as auto-completion, syntax-highlighting, multi-line queries, and query-history.
15 |
16 | Beyond these, some distinguishing features of **odbc-cli** are:
17 |
18 | - **Multi DBMS support**: In addition to supporting connections to multiple DBMS, with **odbc-cli** you can connect to, and query multiple databases in the same session.
19 | - **An integrated object browser**: Navigate between connections and objects within a database.
20 | - **Small footprint and excellent performance**: One of the main motivations is to reduce both the on-disk, as well as in-memory footprint of the [existing Microsoft SQL Server client](https://github.com/dbcli/mssql-cli/), while at the same time improve query execution, and time spent retrieving results.
21 | - **Out-of-database auto-completion**: Mostly relevant to SQL Server users, but auto-completion is "aware" of schema and table structure outside of the currently connected catalog / database.
22 |
23 | ## Installing and OS support
24 |
25 | The assumption is that the starting point is a box with a working ODBC setup. This means a driver manager (UnixODBC, for example), together with ODBC drivers that are appropriate to the DBM Systems you intend to connect to.
26 |
27 | To install the latest version of the package marked as *stable*, simply:
28 |
29 | ```sh
30 | python -m pip install odbcli
31 | ```
32 |
33 | *Development* versions, tracking the tip of the master branch, are hosted on Test Pypi, and can be installed, for example by:
34 |
35 | ```sh
36 | python -m pip install --index-url https://test.pypi.org/simple/ odbcli
37 | ```
38 |
39 | Notes:
40 | * In theory, this package should work under Windows, MacOS, as well as Linux. I can only test Linux; help testing and developing on the other platforms (as well as Linux) is very much welcome.
41 | * The main supporting package, [**cyanodbc**](https://github.com/cyanodbc/cyanodbc) comes as a pre-compiled wheel. It requires a modern C++ library supporting the C++14 standard. The cyanodbc Linux wheel is built on Ubuntu 16 - not exactly bleeding edge. Anything newer should be fine.
42 | * As of https://github.com/detule/odbc-cli/commit/bea22885d0483de0c1899ebc26ff853568b0e417, **odbc-cli** requires `cyanodbc` version [0.0.2.136](https://test.pypi.org/project/Cyanodbc/0.0.2.136/#files) or newer.
43 |
44 | ## Usage
45 |
46 | See the [Usage section here](https://detule.github.io/odbc-cli/index.html#Usage).
47 |
48 | ## Supported DBMS
49 |
50 | I have had a chance to test connectivity and basic functionality to the following DBM Systems:
51 |
52 | * **Microsoft SQL Server**
53 | Support and usability here should be furthest along. While I encounter (and fix) an occasional issue, I use this client in this capacity daily.
54 |
55 | Driver notes:
56 | * OEM Driver: No known issues (I test with driver version 17.5).
57 | * FreeTDS: Please use version 1.2 or newer for optimal performance (older versions do not support the SQLColumns API endpoint applied to tables out-of-currently-connected-catalog).
58 |
59 | * **MySQL**
60 | I have had a chance to test connectivity and basic functionality, but contributor help very much appreciated.
61 |
62 | * **SQLite**
63 | I have had a chance to test connectivity and basic functionality, but contributor help very much appreciated.
64 |
65 | * **PostgreSQL**
66 | I have had a chance to test connectivity and basic functionality, but contributor help very much appreciated.
67 |
68 | Driver notes:
69 | * Please consider using [psqlODBC 12.01](https://odbc.postgresql.org/docs/release.html) or newer for optimal performance (older versions, when used with a PostgreSQL 12.0, seem to have a documented bug when calling into SQLColumns).
70 |
71 | * **Snowflake**
72 | I have had a chance to test connectivity and basic functionality, but contributor help very much appreciated.
73 |
74 | Driver notes:
75 | * As of version 2.20 of their ODBC driver, consider specifying the `Database` field in the DSN configuration section in your INI files. If no `Database` is specified when connecting, their driver will report the empty string - despite being attached to a particular catalog. Subsequently, post-connection specifying the database using `USE` works as expected.
76 |
77 | * **Other** DBM Systems with ODBC drivers not mentioned above should work with minimal, or hopefully no additional, configuration / effort.
78 |
79 | ## Reporting issues
80 |
81 | The best feature - multi DBMS support, is also a curse from a support perspective, as there are too-many-to-count combinations of:
82 |
83 | * Client platform (ex: Debian 10)
84 | * Data base system (ex: SQL Server)
85 | * Data base version (ex: 19)
86 | * ODBC driver manager (ex: unixODBC)
87 | * ODBC driver manager version (ex: 2.3.x)
88 | * ODBC driver (ex: FreeTDS)
89 | * ODBC driver version (ex: 1.2.3)
90 |
91 | that could be specific to your setup, contributing to the problem and making it difficult to replicate. Please consider including all of this information when reporting the issue, but above all be prepared that I may not be able to replicate and fix your issue (and therefore, hopefully you can contribute / code-up a solution). Since the use case for this client is so broad, the only way I see this project having decent support is if we build up a critical mass of user/developers.
92 |
93 | ## Troubleshooting
94 |
95 | ### Listing connections and connecting to databases
96 |
97 | The best way to resolve connectivity issues is to work directly in a python console. In particular, try working directly with the `cyanodbc` package in an interactive session.
98 |
99 | * When starting the client, **odbc-cli** queries the driver manager for a list of available connections by executing:
100 |
101 | ```
102 | import cyanodbc
103 | cyanodbc.datasources()
104 | ```
105 |
106 | Make sure this command returns a coherent output / agrees with your expectations before attempting anything else. If it does not, consult the documentation for your driver manager and make sure all the appropriate INI files are populated accordingly.
107 |
108 | * If for example, you are attempting to connect to a DSN called `postgresql_db` - recall this should be defined and configured in the INI configuration file appropriate to your driver manager, in the background, **odbc-cli** attempts to establish a connection with a connection string similar to:
109 |
110 | ```
111 | import cyanodbc
112 | conn = cyanodbc.connect("DSN=postgresql_db;UID=postgres;PWD=password")
113 | ```
114 |
115 | If experiencing issues connecting to a database, make sure you can establish a connection using the method above, before moving on to troubleshoot other parts of the client.
116 |
117 | ## Acknowledgements
118 |
119 | This project would not be possible without the most excellent [python prompt toolkit library](https://github.com/prompt-toolkit/python-prompt-toolkit). In addition, idea and code sharing between the [clients that leverage this library](https://github.com/dbcli/) is rampant, and this project is no exception - a big thanks to all the `dbcli` contributors.
120 |
--------------------------------------------------------------------------------
/ci/travis/deploy.sh:
--------------------------------------------------------------------------------
#!/bin/bash -ue
#
# Travis deploy hook.
# Every deployed build uploads the sdist tarball to Test PyPI; tagged
# builds additionally publish the same artifact to PyPI proper.
# (The original if/else duplicated the Test PyPI upload verbatim in both
# branches, and the messages said "wheel" although the glob matches the
# sdist *.tar.gz produced by `python setup.py sdist`.)

echo "Deploying sdist to test pypi"
twine upload --repository-url https://test.pypi.org/legacy/ -u $TEST_PYPI_USERNAME -p $TEST_PYPI_PASSWORD $TRAVIS_BUILD_DIR/dist/odbcli*.tar.gz

# Tagged builds are releases: also push to the real index.
if ! [[ "$TRAVIS_TAG" == "" ]]
then
    echo "Deploying sdist to pypi"
    twine upload -u $PYPI_USERNAME -p $PYPI_PASSWORD $TRAVIS_BUILD_DIR/dist/odbcli*.tar.gz
fi
--------------------------------------------------------------------------------
/odbc-cli:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launcher shim for running odbc-cli from a source checkout: resolves the
# script's real location (following symlinks) so the odbcli package that
# sits next to this file is importable, then hands off to the package's
# __main__ module.

SOURCE="${BASH_SOURCE[0]}"
while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
  DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )"
  SOURCE="$(readlink "$SOURCE")"
  [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )"

# Set the python io encoding to UTF-8 by default if not set.
if [ -z ${PYTHONIOENCODING+x} ]; then export PYTHONIOENCODING=UTF-8; fi

# Prepend the resolved directory so `import odbcli` finds the local package
# ahead of any installed copy.
export PYTHONPATH="${DIR}:${PYTHONPATH}"

python -m odbcli.__main__ "$@"
17 |
--------------------------------------------------------------------------------
/odbcli/__init__.py:
--------------------------------------------------------------------------------
# Package version. CI mutates this line with sed to append a ".20<build>"
# dev suffix for non-tagged builds (see .travis.yml and
# .github/workflows/publish.yml) -- keep the exact `__version__ = "..."`
# format or those substitutions silently stop matching.
__version__ = "0.0.6"
--------------------------------------------------------------------------------
/odbcli/__main__.py:
--------------------------------------------------------------------------------
"""Module entry point: lets the CLI be started with `python -m odbcli`."""
from .cli import main

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/odbcli/app.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar
5 | from prompt_toolkit.enums import EditingMode
6 | from prompt_toolkit.key_binding import KeyBindings, merge_key_bindings
7 | from prompt_toolkit.keys import Keys
8 | from prompt_toolkit.key_binding.bindings.auto_suggest import load_auto_suggest_bindings
9 | from prompt_toolkit.application import Application
10 | from prompt_toolkit.key_binding.bindings.focus import focus_next
11 | from prompt_toolkit.filters import Condition, has_focus
12 | from logging.handlers import RotatingFileHandler
13 | from cyanodbc import datasources
14 | from .sidebar import myDBConn, myDBObject
15 | from .conn import sqlConnection
16 | from .completion.mssqlcompleter import MssqlCompleter
17 | from .config import get_config, config_location, ensure_dir_exists
18 | from .odbcstyle import style_factory
19 | from .layout import sqlAppLayout
20 |
class ExitEX(Exception):
    # Sentinel exception raised through Application.exit() to break out of
    # the prompt_toolkit event loop when the user confirms quitting.
    pass
23 |
24 | class sqlApp:
    def __init__(
        self,
        odbclirc_file = None
    ) -> None:
        """Build the application state: read config, set up logging and the
        pager, enumerate ODBC datasources into the sidebar linked list, and
        construct the prompt_toolkit Application.

        odbclirc_file: optional path to a config file; when None the
        default location is used (see get_config).
        """
        c = self.config = get_config(odbclirc_file)
        # Logging must come first: set_default_pager logs via self.logger.
        self.initialize_logging()
        self.set_default_pager(c)
        # Cache frequently consulted config values as plain attributes.
        self.mouse_support: bool = c["main"].as_bool("mouse_support")
        self.fetch_chunk_multiplier = c["main"].as_int("fetch_chunk_multiplier")
        self.preview_limit_rows = c["main"].as_int("preview_limit_rows")
        self.preview_chunk_size = c["main"].as_int("preview_fetch_chunk_size")
        self.pager_reserve_lines = c["main"].as_int("pager_reserve_lines")
        self.table_format = c["main"]["table_format"]
        self.timing_enabled = c["main"].as_bool("timing")
        self.syntax_style = c["main"]["syntax_style"]
        self.cli_style = c["colors"]
        self.multiline: bool = c["main"].as_bool("multi_line")
        self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines")

        # UI visibility flags, toggled by key bindings / dialogs.
        self.show_exit_confirmation: bool = False
        self.exit_message: str = "Do you really want to exit?"

        self.show_expanding_object: bool = False

        self.show_sidebar: bool = True
        self.show_login_prompt: bool = False
        self.show_preview: bool = False
        self.show_disconnect_dialog: bool = False
        # Connection queries execute against; None until the user connects.
        self._active_conn = None
        self.obj_list = []
        # Flag to signal to some of the prompt toolkit structures that we need
        # to traverse the obj_list anew to list all the objects in the sidebar.
        # Added for efficiency (no need to traverse unless necessary). Updated
        # from the main thread always, so no need for locking.
        self.obj_list_changed: bool = True
        # This field is a list with two elements. The first is the index of
        # the currently selected object (0-indexed). The second is the
        # index of the currently selected object in the
        # list of objects where each object is counted with length of characters
        # in name + 1 multiplicity. So for example a list of objects
        # A
        # AB
        # ABC
        # ABCD
        # where "ABC" is selected would present the index as 5 ([1+1] + [2+1]).
        # This is used to track the cursor position in the sidebar document
        # It is recorded here, rather than elsewhere because we can track it
        # here far more efficiently (select_next, and select_previous).
        # It is important that all methods of this class that manipulate the
        # currently selected object also update this index.
        self._selected_obj_idx = [0, 0]
        dsns = list(datasources().keys())
        if len(dsns) < 1:
            sys.exit("No datasources found ... exiting.")
        # One sidebar entry per DSN, chained into a singly linked list.
        for dsn in dsns:
            self.obj_list.append(myDBConn(
                my_app = self,
                conn = sqlConnection(dsn = dsn),
                name = dsn,
                otype = "Connection"))
        for i in range(len(self.obj_list) - 1):
            self.obj_list[i].next_object = self.obj_list[i + 1]
        # Loop over side-bar when moving past the element on the bottom
        self.obj_list[len(self.obj_list) - 1].next_object = self.obj_list[0]
        self._selected_object = self.obj_list[0]
        # Completer resolves the connection lazily so it always sees the
        # currently active one.
        self.completer = MssqlCompleter(smart_completion = True, get_conn = lambda: self.active_conn)

        self.application = self._create_application()
    @property
    def active_conn(self) -> sqlConnection:
        # The connection queries currently execute against; None until the
        # user has connected to a datasource (set in __init__).
        return self._active_conn

    @active_conn.setter
    def active_conn(self, conn: sqlConnection) -> None:
        self._active_conn = conn
101 |
    @property
    def selected_object(self) -> myDBObject:
        # Object currently highlighted in the sidebar.
        return self._selected_object

    @selected_object.setter
    def selected_object(self, obj) -> None:
        """ Avoid using / computationally expensive.
        Instead try using select_next / select_previous if possible.
        Will update _selected_obj_idx appropriately.
        """
        # Walk the linked list from the head to recompute both components of
        # _selected_obj_idx (object index and cursor offset, see the comment
        # on _selected_obj_idx in __init__); O(n) in sidebar objects.
        cursor = 0
        idx = 0
        o = self.obj_list[0]
        self._selected_object = obj
        while o is not self._selected_object:
            if not o.next_object:
                # obj is not reachable from the head of the list.
                raise IndexError
            cursor += len(o.name) + 1
            idx += 1
            o = o.next_object
        self._selected_obj_idx = [idx, cursor]
123 |
124 | def select(self, idx) -> None:
125 | """ Select the [i]-th object in the list. Will also update
126 | _selected_obj_idx appropriately.
127 | """
128 | counter = 0
129 | cursor = 0
130 | o = self.obj_list[0]
131 | while counter < idx:
132 | if not o.next_object:
133 | raise IndexError
134 | counter += 1
135 | cursor += len(o.name) + 1
136 | o = o.next_object
137 | self._selected_object = o
138 | self._selected_obj_idx = [idx, cursor]
139 |
    @property
    def selected_object_idx(self):
        # [object index, cursor offset] pair; see the long comment on
        # _selected_obj_idx in __init__ for the encoding.
        return self._selected_obj_idx
143 |
    def select_next(self) -> None:
        # O(1): the sidebar list is singly linked and circular (the tail's
        # next_object is the head, wired up in __init__).
        # NOTE(review): unlike the "down" key handler, this does NOT update
        # _selected_obj_idx -- callers are expected to adjust it themselves.
        self._selected_object = self.selected_object.next_object
146 |
147 | def select_previous(self) -> None:
148 | obj = self.selected_object.parent if self.selected_object.parent is not None else self.obj_list[0]
149 | while obj.next_object is not self.selected_object:
150 | obj = obj.next_object
151 | self._selected_object = obj
152 |
    @property
    def editing_mode(self) -> EditingMode:
        # Delegates to the underlying prompt_toolkit Application.
        return self.application.editing_mode

    @editing_mode.setter
    def editing_mode(self, value: EditingMode) -> None:
        app = self.application
        app.editing_mode = value
161 |
162 | @property
163 | def vi_mode(self) -> bool:
164 | return self.editing_mode == EditingMode.VI
165 |
166 | @vi_mode.setter
167 | def vi_mode(self, value: bool) -> None:
168 | if value:
169 | self.editing_mode = EditingMode.VI
170 | else:
171 | self.editing_mode = EditingMode.EMACS
172 |
    def set_default_pager(self, config):
        """Resolve which pager the client should use and export it.

        Precedence: the config file's "pager" entry, then the PAGER
        environment variable, then the operating system default. Also
        seeds LESS with sensible defaults when unset. Mutates os.environ;
        requires self.logger, so initialize_logging must run first.
        """
        configured_pager = config["main"].get("pager")
        os_environ_pager = os.environ.get("PAGER")

        if configured_pager:
            self.logger.info(
                'Default pager found in config file: "%s"', configured_pager
            )
            os.environ["PAGER"] = configured_pager
        elif os_environ_pager:
            self.logger.info(
                'Default pager found in PAGER environment variable: "%s"',
                os_environ_pager,
            )
            os.environ["PAGER"] = os_environ_pager
        else:
            self.logger.info(
                "No default pager found in environment. Using os default pager"
            )

        # Set default set of less recommended options, if they are not already set.
        # They are ignored if pager is different than less.
        if not os.environ.get("LESS"):
            os.environ["LESS"] = "-SRXF"
197 |
    def initialize_logging(self):
        """Configure the 'odbcli' root logger from the config file.

        Honors the log_file ('default' maps to <config dir>/odbcli.log)
        and log_level settings; a level of NONE installs a no-op handler.
        Sets self.logger for use by the rest of the class.
        """
        log_file = self.config['main']['log_file']
        if log_file == 'default':
            log_file = config_location() + 'odbcli.log'
        ensure_dir_exists(log_file)
        log_level = self.config['main']['log_level']

        # Disable logging if value is NONE by switching to a no-op handler.
        # Set log level to a high value so it doesn't even waste cycles getting
        # called.
        if log_level.upper() == 'NONE':
            handler = logging.NullHandler()
        else:
            # creates a log buffer with max size of 20 MB and 5 backup files
            handler = RotatingFileHandler(os.path.expanduser(log_file),
                    encoding='utf-8', maxBytes=1024*1024*20, backupCount=5)

        level_map = {'CRITICAL': logging.CRITICAL,
                     'ERROR': logging.ERROR,
                     'WARNING': logging.WARNING,
                     'INFO': logging.INFO,
                     'DEBUG': logging.DEBUG,
                     'NONE': logging.CRITICAL
                     }

        log_level = level_map[log_level.upper()]

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        root_logger = logging.getLogger('odbcli')
        root_logger.addHandler(handler)
        root_logger.setLevel(log_level)

        root_logger.info('Initializing odbcli logging.')
        root_logger.debug('Log file %r.', log_file)
        self.logger = logging.getLogger(__name__)
238 |
239 | def _create_application(self) -> Application:
240 | self.sql_layout = sqlAppLayout(my_app = self)
241 | kb = KeyBindings()
242 |
243 | confirmation_visible = Condition(lambda: self.show_exit_confirmation)
244 | @kb.add("c-q")
245 | def _(event):
246 | " Pressing Ctrl-Q or Ctrl-C will exit the user interface. "
247 | self.show_exit_confirmation = True
248 |
249 | @kb.add("y", filter=confirmation_visible)
250 | @kb.add("Y", filter=confirmation_visible)
251 | @kb.add("enter", filter=confirmation_visible)
252 | @kb.add("c-q", filter=confirmation_visible)
253 | def _(event):
254 | """
255 | Really quit.
256 | """
257 | event.app.exit(exception = ExitEX(), style="class:exiting")
258 |
259 | @kb.add(Keys.Any, filter=confirmation_visible)
260 | def _(event):
261 | """
262 | Cancel exit.
263 | """
264 | self.show_exit_confirmation = False
265 |
266 | # Global key bindings.
267 | @kb.add("tab", filter = Condition(lambda: self.show_preview or self.show_login_prompt))
268 | def _(event):
269 | event.app.layout.focus_next()
270 | @kb.add("f4")
271 | def _(event):
272 | " Toggle between Emacs and Vi mode. "
273 | self.vi_mode = not self.vi_mode
274 | # apparently ctrls does this
275 | @kb.add("c-t", filter = Condition(lambda: not self.show_preview))
276 | def _(event):
277 | """
278 | Show/hide sidebar.
279 | """
280 | self.show_sidebar = not self.show_sidebar
281 | if self.show_sidebar:
282 | event.app.layout.focus("sidebarbuffer")
283 | else:
284 | event.app.layout.focus_previous()
285 |
286 | sidebar_visible = Condition(lambda: self.show_sidebar and not self.show_expanding_object and not self.show_login_prompt and not self.show_preview) \
287 | & ~has_focus("sidebarsearchbuffer")
288 | @kb.add("up", filter=sidebar_visible)
289 | @kb.add("c-p", filter=sidebar_visible)
290 | @kb.add("k", filter=sidebar_visible)
291 | def _(event):
292 | " Go to previous option. "
293 | obj = self._selected_object
294 | self.select_previous()
295 | inc = len(self.selected_object.name) + 1 # newline character
296 | if obj is self.obj_list[0]:
297 | idx = 0
298 | cursor = 0
299 | while obj is not self._selected_object:
300 | if not obj.next_object:
301 | raise IndexError
302 | cursor += len(obj.name) + 1
303 | idx += 1
304 | obj = obj.next_object
305 | self._selected_obj_idx = [idx, cursor]
306 | else:
307 | self._selected_obj_idx[0] -= 1
308 | self._selected_obj_idx[1] -= inc
309 |
310 | @kb.add("down", filter=sidebar_visible)
311 | @kb.add("c-n", filter=sidebar_visible)
312 | @kb.add("j", filter=sidebar_visible)
313 | def _(event):
314 | " Go to next option. "
315 | inc = len(self.selected_object.name) + 1 # newline character
316 | self.select_next()
317 | if self.selected_object is self.obj_list[0]:
318 | self._selected_obj_idx = [0, 0]
319 | else:
320 | self._selected_obj_idx[0] += 1
321 | self._selected_obj_idx[1] += inc
322 |
323 | @kb.add("enter", filter = sidebar_visible)
324 | def _(event):
325 | " If connection, connect. If table preview"
326 | obj = self.selected_object
327 | if type(obj).__name__ == "myDBConn" and not obj.conn.connected():
328 | self.show_login_prompt = True
329 | event.app.layout.focus(self.sql_layout.lprompt)
330 | if type(obj).__name__ == "myDBConn" and obj.conn.connected():
331 | # OG: some thread locking may be needed here
332 | self._active_conn = obj.conn
333 | elif obj.otype in ["table", "view", "function"]:
334 | self.show_preview = True
335 | self.show_sidebar = False
336 | event.app.layout.focus(self.sql_layout.preview)
337 |
338 | @kb.add("right", filter=sidebar_visible)
339 | @kb.add("l", filter=sidebar_visible)
340 | @kb.add(" ", filter=sidebar_visible)
341 | def _(event):
342 | " Select next value for current option. "
343 | obj = self.selected_object
344 | obj.expand()
345 | if type(obj).__name__ == "myDBConn" and not obj.conn.connected():
346 | self.show_login_prompt = True
347 | event.app.layout.focus(self.sql_layout.lprompt)
348 |
349 | @kb.add("left", filter=sidebar_visible)
350 | @kb.add("h", filter=sidebar_visible)
351 | def _(event):
352 | " Select next value for current option. "
353 | obj = self.selected_object
354 | if type(obj).__name__ == "myDBConn" and obj.conn.connected() and obj.children is None:
355 | self.show_disconnect_dialog = True
356 | event.app.layout.focus(self.sql_layout.disconnect_dialog)
357 | else:
358 | obj.collapse()
359 |
360 | auto_suggest_bindings = load_auto_suggest_bindings()
361 |
362 | return Application(
363 | layout = self.sql_layout.layout,
364 | key_bindings = merge_key_bindings([kb, auto_suggest_bindings]),
365 | enable_page_navigation_bindings = True,
366 | style = style_factory(self.syntax_style, self.cli_style),
367 | include_default_pygments_style = False,
368 | mouse_support = self.mouse_support,
369 | full_screen = False,
370 | editing_mode = EditingMode.VI if self.config["main"].as_bool("vi") else EditingMode.EMACS
371 | )
372 |
--------------------------------------------------------------------------------
/odbcli/cli.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | A simple example of a calculator program.
4 | This could be used as inspiration for a REPL.
5 | """
6 | import os
7 | from sys import stderr
8 | from time import time
9 | from cyanodbc import DatabaseError, datasources
10 | from click import echo_via_pager, secho
11 | from prompt_toolkit.patch_stdout import patch_stdout
12 | from prompt_toolkit.utils import get_cwidth
13 | from .app import sqlApp, ExitEX
14 | from .layout import sqlAppLayout
15 | from .conn import connStatus, executionStatus
16 |
17 |
def main():
    """Run the odbcli REPL.

    Repeatedly runs the prompt_toolkit application; each run returns a
    (kind, sql) pair where kind is e.g. "preview". Executes the SQL on the
    appropriate connection, pages the results, and loops until the app
    raises ExitEX (user quit), at which point all connections are closed.
    """

    my_app = sqlApp()
    # with patch_stdout():
    while True:
        try:
            app_res = my_app.application.run()
        except ExitEX:
            # User quit: close every registered connection before returning.
            for i in range(len(my_app.obj_list)):
                my_app.obj_list[i].conn.close()
            return
        else:
            # If it's a preview query we need an indication
            # of where to run the query
            if(app_res[0] == "preview"):
                sql_conn = my_app.selected_object.conn
            else:
                sql_conn = my_app.active_conn
            if sql_conn is not None:
                #TODO also check that it is connected
                try:
                    secho("Executing query...Ctrl-c to cancel", err = False)
                    start = time()
                    crsr = sql_conn.async_execute(app_res[1])
                    execution = time() - start
                    secho("Query execution...done", err = False)
                    if(app_res[0] == "preview"):
                        # Preview results are rendered inside the app's own
                        # preview widget, not via the pager below.
                        sql_conn.async_fetchall(my_app.preview_chunk_size,
                            my_app.application)
                        continue
                    if my_app.timing_enabled:
                        print("Time: %0.03fs" % execution)

                    if sql_conn.execution_status == executionStatus.FAIL:
                        err = sql_conn.execution_err
                        secho("Query error: %s\n" % err, err = True, fg = "red")
                    else:
                        # crsr.description is empty for statements that
                        # return no result set (DDL/DML).
                        if crsr.description:
                            cols = [col.name for col in crsr.description]
                        else:
                            cols = []
                        if len(cols):
                            # ht - 3 - pager_reserve_lines approximates the
                            # usable screen height for one page of results.
                            ht = my_app.application.output.get_size()[0]
                            sql_conn.async_fetchall(my_app.fetch_chunk_multiplier *
                                (ht - 3 - my_app.pager_reserve_lines), my_app.application)
                            formatted = sql_conn.formatted_fetch(ht - 3 - my_app.pager_reserve_lines, cols, my_app.table_format)
                            echo_via_pager(formatted)
                        else:
                            secho("No rows returned\n", err = False)
                except KeyboardInterrupt:
                    # Ctrl-C during execution/fetch: cancel server-side work.
                    secho("Cancelling query...", err = True, fg = "red")
                    sql_conn.cancel()
                    secho("Query cancelled.", err = True, fg = "red")
                sql_conn.close_cursor()
72 |
--------------------------------------------------------------------------------
/odbcli/completion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dbcli/odbc-cli/2b061f4d700067ee3ceaca24da86af5cd3e8b21f/odbcli/completion/__init__.py
--------------------------------------------------------------------------------
/odbcli/completion/mssqlliterals/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dbcli/odbc-cli/2b061f4d700067ee3ceaca24da86af5cd3e8b21f/odbcli/completion/mssqlliterals/__init__.py
--------------------------------------------------------------------------------
/odbcli/completion/mssqlliterals/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
# Locate the JSON literal catalogue shipped alongside this module.
root = os.path.dirname(__file__)
literal_file = os.path.join(root, 'sqlliterals.json')

# Load once at import time. An explicit encoding keeps the JSON parse
# independent of the platform's default locale encoding.
with open(literal_file, encoding='utf-8') as f:
    literals = json.load(f)
9 |
10 |
def get_literals(literal_type, type_=tuple):
    """Return the literals of the requested kind.

    *literal_type* is one of 'keywords', 'functions', 'datatypes';
    the stored values are converted via *type_* (tuple by default).
    """
    values = literals[literal_type]
    return type_(values)
16 |
--------------------------------------------------------------------------------
/odbcli/completion/mssqlliterals/sqlliterals.json:
--------------------------------------------------------------------------------
1 | {
2 | "keywords": {
3 | "ADD": [],
4 | "ALL": [],
5 | "ALTER": [
6 | "APPLICATION ROLE",
7 | "ASSEMBLY",
8 | "ASYMMETRIC KEY",
9 | "AUTHORIZATION",
10 | "AVAILABILITY GROUP",
11 | "BROKER PRIORITY",
12 | "CERTIFICATE",
13 | "COLUMN ENCRYPTION KEY",
14 | "CREDENTIAL",
15 | "CRYPTOGRAPHIC PROVIDER",
16 | "DATABASE",
17 | "DATABASE AUDIT SPECIFICATION",
18 | "DATABASE ENCRYPTION KEY",
19 | "DATABASE HADR",
20 | "DATABASE SCOPED CREDENTIAL",
21 | "DATABASE SCOPED CONFIGURATION",
22 | "DATABASE SET Options",
23 | "ENDPOINT",
24 | "EVENT SESSION",
25 | "EXTERNAL DATA SOURCE",
26 | "EXTERNAL LIBRARY",
27 | "EXTERNAL RESOURCE POOL",
28 | "FULLTEXT CATALOG",
29 | "FULLTEXT INDEX",
30 | "FULLTEXT STOPLIST",
31 | "FUNCTION",
32 | "INDEX",
33 | "LOGIN",
34 | "MASTER KEY",
35 | "MESSAGE TYPE",
36 | "PARTITION FUNCTION",
37 | "PARTITION SCHEME",
38 | "PROCEDURE",
39 | "QUEUE",
40 | "REMOTE SERVICE BINDING",
41 | "RESOURCE GOVERNOR",
42 | "RESOURCE POOL",
43 | "ROLE",
44 | "ROUTE",
45 | "SCHEMA",
46 | "SEARCH PROPERTY LIST",
47 | "SECURITY POLICY",
48 | "SEQUENCE",
49 | "SERVER AUDIT",
50 | "SERVER AUDIT SPECIFICATION",
51 | "SERVER CONFIGURATION",
52 | "SERVER ROLE",
53 | "SERVICE",
54 | "SERVICE MASTER KEY",
55 | "SYMMETRIC KEY",
56 | "TABLE",
57 | "TRIGGER",
58 | "USER",
59 | "VIEW",
60 | "WORKLOAD GROUP",
61 | "XML SCHEMA COLLECTION"
62 | ],
63 | "ANY": [],
64 | "AND": [],
65 | "AS": [],
66 | "ASC": [],
67 | "AUTHORIZATION": [],
68 | "BACKUP": [],
69 | "BEGIN": [],
70 | "BETWEEN": [],
71 | "BREAK": [],
72 | "BROWSE": [],
73 | "BULK": [],
74 | "BY": [],
75 | "CASCADE": [],
76 | "CASE": [],
77 | "CHECK": [],
78 | "CHECKPOINT": [],
79 | "CLOSE": [],
80 | "CLUSTERED": [],
81 | "COALESCE": [],
82 | "COLLATE": [],
83 | "COLUMN": [],
84 | "COMMIT": [],
85 | "COMPUTE": [],
86 | "CONSTRAINT": [],
87 | "CONTAINS": [],
88 | "CONTAINSTABLE": [],
89 | "CONTINUE": [],
90 | "CONVERT": [],
91 | "CREATE": [
92 | "APPLICATION ROLE",
93 | "ASSEMBLY",
94 | "ASYMMETRIC KEY",
95 | "AUTHORIZATION",
96 | "AVAILABILITY GROUP",
97 | "BROKER PRIORITY",
98 | "CERTIFICATE",
99 | "COLUMN ENCRYPTION KEY",
100 | "CREDENTIAL",
101 | "CRYPTOGRAPHIC PROVIDER",
102 | "DATABASE",
103 | "DATABASE AUDIT SPECIFICATION",
104 | "DATABASE ENCRYPTION KEY",
105 | "DATABASE HADR",
106 | "DATABASE SCOPED CREDENTIAL",
107 | "DATABASE SCOPED CONFIGURATION",
108 | "DATABASE SET Options",
109 | "ENDPOINT",
110 | "EVENT SESSION",
111 | "EXTERNAL DATA SOURCE",
112 | "EXTERNAL LIBRARY",
113 | "EXTERNAL RESOURCE POOL",
114 | "FULLTEXT CATALOG",
115 | "FULLTEXT INDEX",
116 | "FULLTEXT STOPLIST",
117 | "FUNCTION",
118 | "INDEX",
119 | "LOGIN",
120 | "MASTER KEY",
121 | "MESSAGE TYPE",
122 | "PARTITION FUNCTION",
123 | "PARTITION SCHEME",
124 | "PROCEDURE",
125 | "QUEUE",
126 | "REMOTE SERVICE BINDING",
127 | "RESOURCE GOVERNOR",
128 | "RESOURCE POOL",
129 | "ROLE",
130 | "ROUTE",
131 | "SCHEMA",
132 | "SEARCH PROPERTY LIST",
133 | "SECURITY POLICY",
134 | "SEQUENCE",
135 | "SERVER AUDIT",
136 | "SERVER AUDIT SPECIFICATION",
137 | "SERVER CONFIGURATION",
138 | "SERVER ROLE",
139 | "SERVICE",
140 | "SERVICE MASTER KEY",
141 | "SYMMETRIC KEY",
142 | "TABLE",
143 | "TRIGGER",
144 | "USER",
145 | "VIEW",
146 | "WORKLOAD GROUP",
147 | "XML SCHEMA COLLECTION"
148 | ],
149 | "CROSS": [],
150 | "CURRENT": [],
151 | "CURRENT_DATE": [],
152 | "CURRENT_TIME": [],
153 | "CURRENT_TIMESTAMP": [],
154 | "CURRENT_USER": [],
155 | "CURSOR": [],
156 | "DATABASE": [],
157 | "DBCC": [],
158 | "DEALLOCATE": [],
159 | "DECLARE": [],
160 | "DEFAULT": [],
161 | "DELETE": [],
162 | "DENY": [],
163 | "DESC": [],
164 | "DISK": [],
165 | "DISTINCT": [],
166 | "DISTRIBUTED": [],
167 | "DOUBLE": [],
168 | "DROP": [
169 | "APPLICATION ROLE",
170 | "ASSEMBLY",
171 | "ASYMMETRIC KEY",
172 | "AUTHORIZATION",
173 | "AVAILABILITY GROUP",
174 | "BROKER PRIORITY",
175 | "CERTIFICATE",
176 | "COLUMN ENCRYPTION KEY",
177 | "CREDENTIAL",
178 | "CRYPTOGRAPHIC PROVIDER",
179 | "DATABASE",
180 | "DATABASE AUDIT SPECIFICATION",
181 | "DATABASE ENCRYPTION KEY",
182 | "DATABASE HADR",
183 | "DATABASE SCOPED CREDENTIAL",
184 | "DATABASE SCOPED CONFIGURATION",
185 | "DATABASE SET Options",
186 | "ENDPOINT",
187 | "EVENT SESSION",
188 | "EXTERNAL DATA SOURCE",
189 | "EXTERNAL LIBRARY",
190 | "EXTERNAL RESOURCE POOL",
191 | "FULLTEXT CATALOG",
192 | "FULLTEXT INDEX",
193 | "FULLTEXT STOPLIST",
194 | "FUNCTION",
195 | "INDEX",
196 | "LOGIN",
197 | "MASTER KEY",
198 | "MESSAGE TYPE",
199 | "PARTITION FUNCTION",
200 | "PARTITION SCHEME",
201 | "PROCEDURE",
202 | "QUEUE",
203 | "REMOTE SERVICE BINDING",
204 | "RESOURCE GOVERNOR",
205 | "RESOURCE POOL",
206 | "ROLE",
207 | "ROUTE",
208 | "SCHEMA",
209 | "SEARCH PROPERTY LIST",
210 | "SECURITY POLICY",
211 | "SEQUENCE",
212 | "SERVER AUDIT",
213 | "SERVER AUDIT SPECIFICATION",
214 | "SERVER CONFIGURATION",
215 | "SERVER ROLE",
216 | "SERVICE",
217 | "SERVICE MASTER KEY",
218 | "SYMMETRIC KEY",
219 | "TABLE",
220 | "TRIGGER",
221 | "USER",
222 | "VIEW",
223 | "WORKLOAD GROUP",
224 | "XML SCHEMA COLLECTION"
225 | ],
226 | "DUMP": [],
227 | "ELSE": [],
228 | "END": [],
229 | "ERRLVL": [],
230 | "ESCAPE": [],
231 | "EXCEPT": [],
232 | "EXEC": [],
233 | "EXECUTE": [],
234 | "EXISTS": [],
235 | "EXIT": [],
236 | "EXTERNAL": [],
237 | "FETCH": [],
238 | "FILE": [],
239 | "FILLFACTOR": [],
240 | "FOR": [],
241 | "FOREIGN": [],
242 | "FREETEXT": [],
243 | "FREETEXTTABLE": [],
244 | "FROM": [],
245 | "FULL": [],
246 | "FUNCTION": [],
247 | "GOTO": [],
248 | "GRANT": [],
249 | "GROUP": [],
250 | "GROUP BY": [],
251 | "HAVING": [],
252 | "HOLDLOCK": [],
253 | "IDENTITY": [],
254 | "IDENTITY_INSERT": [],
255 | "IDENTITYCOL": [],
256 | "IF": [],
257 | "IN": [],
258 | "INDEX": [],
259 | "INNER": [],
260 | "INSERT": [],
261 | "INTERSECT": [],
262 | "INTO": [],
263 | "IS": [],
264 | "JOIN": [],
265 | "KEY": [],
266 | "KILL": [],
267 | "LEFT": [],
268 | "LIKE": [],
269 | "LIMIT": [],
270 | "LINENO": [],
271 | "LOAD": [],
272 | "MERGE": [],
273 | "NATIONAL": [],
274 | "NOCHECK": [],
275 | "NONCLUSTERED": [],
276 | "NOT": [],
277 | "NULL": [],
278 | "NULLIF": [],
279 | "OF": [],
280 | "OFF": [],
281 | "OFFSETS": [],
282 | "ON": [],
283 | "OPEN": [],
284 | "OPENDATASOURCE": [],
285 | "OPENQUERY": [],
286 | "OPENROWSET": [],
287 | "OPENXML": [],
288 | "OPTION": [],
289 | "OR": [],
290 | "ORDER": [],
291 | "OUTER": [],
292 | "OVER": [],
293 | "PERCENT": [],
294 | "PIVOT": [],
295 | "PLAN": [],
296 | "PRECISION": [],
297 | "PRIMARY": [],
298 | "PRINT": [],
299 | "PROC": [],
300 | "PROCEDURE": [],
301 | "PUBLIC": [],
302 | "RAISERROR": [],
303 | "READ": [],
304 | "READTEXT": [],
305 | "RECONFIGURE": [],
306 | "REFERENCES": [],
307 | "REPLICATION": [],
308 | "RESTORE": [],
309 | "RESTRICT": [],
310 | "RETURN": [],
311 | "REVERT": [],
312 | "REVOKE": [],
313 | "RIGHT": [],
314 | "ROLLBACK": [],
315 | "ROWCOUNT": [],
316 | "ROWGUIDCOL": [],
317 | "RULE": [],
318 | "SAVE": [],
319 | "SCHEMA": [],
320 | "SECURITYAUDIT": [],
321 | "SELECT": [],
322 | "SEMANTICKEYPHRASETABLE": [],
323 | "SEMANTICSIMILARITYDETAILSTABLE": [],
324 | "SEMANTICSIMILARITYTABLE": [],
325 | "SESSION_USER": [],
326 | "SET": [],
327 | "SETUSER": [],
328 | "SHUTDOWN": [],
329 | "SOME": [],
330 | "STATISTICS": [],
331 | "SYSTEM_USER": [],
332 | "TABLE": [],
333 | "TABLESAMPLE": [],
334 | "TEXTSIZE": [],
335 | "THEN": [],
336 | "TO": [],
337 | "TOP": [],
338 | "TRAN": [],
339 | "TRANSFER": [],
340 | "TRANSACTION": [],
341 | "TRIGGER": [],
342 | "TRUNCATE": [],
343 | "TRY_CONVERT": [],
344 | "TSEQUAL": [],
345 | "UNION": [],
346 | "UNIQUE": [],
347 | "UNPIVOT": [],
348 | "UPDATE": [],
349 | "UPDATETEXT": [],
350 | "USE": [],
351 | "USER": [],
352 | "VALUES": [],
353 | "VARYING": [],
354 | "VIEW": [],
355 | "WAITFOR": [],
356 | "WHEN": [],
357 | "WHERE": [],
358 | "WHILE": [],
359 | "WITH": [],
360 | "WITHIN GROUP": [],
361 | "WRITETEXT": []
362 | },
363 | "functions": [
364 | "AVG",
365 | "CHECKSUM_AGG",
366 | "COUNT",
367 | "COUNT_BIG",
368 | "GROUPING",
369 | "GROUPING_ID",
370 | "MAX",
371 | "MIN",
372 | "STDEV",
373 | "STDEVP",
374 | "STRING_SPLIT",
375 | "SUM",
376 | "VAR",
377 | "VARP"
378 | ],
379 | "datatypes": [
380 | "BIGINT",
381 | "BINARY",
382 | "BIT",
383 | "CHAR",
384 | "CURSOR",
385 | "DATE",
386 | "DATETIME",
387 | "DATETIME2",
388 | "DATETIMEOFFSET",
389 | "DECIMAL",
390 | "FLOAT",
391 | "HIERARCHYID",
392 | "IMAGE",
393 | "INT",
394 | "MONEY",
395 | "NCHAR",
396 | "NTEXT",
397 | "NUMERIC",
398 | "NVARCHAR",
399 | "REAL",
400 | "ROWVERSION",
401 | "SMALLDATETIME",
402 | "SMALLINT",
403 | "SMALLMONEY",
404 | "SQL_VARIANT",
405 | "TABLE",
406 | "TEXT",
407 | "TIME",
408 | "TINYINT",
409 | "UNIQUEIDENTIFIER",
410 | "VARBINARY",
411 | "VARCHAR",
412 | "XML"
413 | ],
414 | "reserved": [
415 | ]
416 | }
417 |
--------------------------------------------------------------------------------
/odbcli/completion/parseutils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dbcli/odbc-cli/2b061f4d700067ee3ceaca24da86af5cd3e8b21f/odbcli/completion/parseutils/__init__.py
--------------------------------------------------------------------------------
/odbcli/completion/parseutils/ctes.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 | from collections import namedtuple
3 | from sqlparse import parse
4 | from sqlparse.tokens import Keyword, CTE, DML
5 | from sqlparse.sql import Identifier, IdentifierList, Parenthesis
6 | from .meta import TableMetadata, ColumnMetadata
7 |
8 |
# TableExpression is a namedtuple representing a CTE, used internally
# name: cte alias assigned in the query
# columns: tuple of output column names (see extract_column_names)
# start: index into the original string of the left parens starting the CTE
# stop: index into the original string of the right parens ending the CTE
TableExpression = namedtuple('TableExpression', 'name columns start stop')
15 |
16 |
def isolate_query_ctes(full_text, text_before_cursor):
    """Simplify a query by converting CTEs into table metadata objects.

    Returns (full_text, text_before_cursor, meta) where meta is a tuple of
    TableMetadata records for the CTEs defined before the cursor. When the
    cursor sits inside a CTE body, full_text/text_before_cursor are narrowed
    to that body; otherwise they are narrowed to the main query after the
    last CTE.
    """

    if not full_text:
        return full_text, text_before_cursor, tuple()

    ctes, _ = extract_ctes(full_text)
    if not ctes:
        return full_text, text_before_cursor, ()

    current_position = len(text_before_cursor)
    meta = []

    for cte in ctes:
        if cte.start < current_position < cte.stop:
            # Currently editing a cte - treat its body as the current full_text
            text_before_cursor = full_text[cte.start:current_position]
            full_text = full_text[cte.start:cte.stop]
            # Return a tuple for consistency with every other return path
            # (this branch previously leaked the raw list).
            return full_text, text_before_cursor, tuple(meta)

        # Append this cte to the list of available table metadata
        cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
        meta.append(TableMetadata(cte.name, cols))

    # Editing past the last cte (ie the main body of the query)
    full_text = full_text[ctes[-1].stop:]
    text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]

    return full_text, text_before_cursor, tuple(meta)
47 |
48 |
def extract_ctes(sql):
    """ Extract common table expressions from a query

    Returns tuple (ctes, remainder_sql)

    ctes is a list of TableExpression namedtuples
    remainder_sql is the text from the original query after the CTEs have
    been stripped.
    """

    # Only the first statement of the input is considered.
    p = parse(sql)[0]

    # Make sure the first meaningful token is "WITH" which is necessary to
    # define CTEs
    idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
    if not (tok and tok.ttype == CTE):
        return [], sql

    # Get the next (meaningful) token, which should be the first CTE
    idx, tok = p.token_next(idx)
    if not tok:
        return ([], '')
    start_pos = token_start_pos(p.tokens, idx)
    ctes = []

    if isinstance(tok, IdentifierList):
        # Multiple ctes
        for t in tok.get_identifiers():
            # Offset of this CTE relative to the start of the identifier list.
            cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
            cte = get_cte_from_token(t, start_pos + cte_start_offset)
            if not cte:
                continue
            ctes.append(cte)
    elif isinstance(tok, Identifier):
        # A single CTE
        cte = get_cte_from_token(tok, start_pos)
        if cte:
            ctes.append(cte)

    idx = p.token_index(tok) + 1

    # Collapse everything after the ctes into a remainder query
    remainder = u''.join(str(tok) for tok in p.tokens[idx:])

    return ctes, remainder
94 |
95 |
def get_cte_from_token(tok, pos0):
    """Build a TableExpression from a single CTE Identifier token.

    *pos0* is the character offset of *tok* within the original SQL string.
    Returns None when the token has no real name or no parenthesized body.
    """
    cte_name = tok.get_real_name()
    if not cte_name:
        return None

    # Find the start position of the opening parens enclosing the cte body
    idx, parens = tok.token_next_by(Parenthesis)
    if not parens:
        return None

    start_pos = pos0 + token_start_pos(tok.tokens, idx)
    cte_len = len(str(parens))  # includes parens
    stop_pos = start_pos + cte_len

    column_names = extract_column_names(parens)

    return TableExpression(cte_name, column_names, start_pos, stop_pos)
113 |
114 |
def extract_column_names(parsed):
    """Return a tuple of output column names for the CTE body *parsed*.

    For SELECT the names come from the select list; for INSERT/UPDATE/DELETE
    they come from the RETURNING clause. Returns () for anything else.
    """
    # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
    idx, tok = parsed.token_next_by(t=DML)
    tok_val = tok and tok.value.lower()

    if tok_val in ('insert', 'update', 'delete'):
        # Jump ahead to the RETURNING clause where the list of column names is
        # NOTE(review): the positional call passes idx as token_next_by's
        # first parameter — confirm this matches the pinned sqlparse
        # signature (i, m, t, idx) before upgrading sqlparse.
        idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
    elif not tok_val == 'select':
        # Must be invalid CTE
        return ()

    # The next token should be either a column name, or a list of column names
    idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
    return tuple(t.get_name() for t in _identifiers(tok))
130 |
131 |
def token_start_pos(tokens, idx):
    """Character offset of tokens[idx] within the concatenated token text."""
    offset = 0
    for tok in tokens[:idx]:
        offset += len(str(tok))
    return offset
134 |
135 |
def _identifiers(tok):
    """Yield the Identifier tokens contained in *tok*.

    *tok* may be a single Identifier, an IdentifierList, or anything else
    (which yields nothing).
    """
    if isinstance(tok, Identifier):
        yield tok
    elif isinstance(tok, IdentifierList):
        # NB: IdentifierList.get_identifiers() can return non-identifiers!
        for candidate in tok.get_identifiers():
            if isinstance(candidate, Identifier):
                yield candidate
144 |
145 |
--------------------------------------------------------------------------------
/odbcli/completion/parseutils/meta.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=too-many-instance-attributes
2 | # pylint: disable=too-many-arguments
3 |
4 | from __future__ import print_function, unicode_literals
5 | from collections import namedtuple
6 |
# Raw record type; use the ColumnMetadata() factory below, which fills in
# sensible defaults for the optional fields.
_ColumnMetadata = namedtuple(
    'ColumnMetadata',
    ['name', 'datatype', 'foreignkeys', 'default', 'has_default']
)


def ColumnMetadata(
    name, datatype, foreignkeys=None, default=None, has_default=False
):
    """Build a ColumnMetadata record, normalizing a falsy *foreignkeys*
    (None, empty tuple, ...) to an empty list."""
    keys = foreignkeys if foreignkeys else []
    return _ColumnMetadata(name, datatype, keys, default, has_default)
19 |
20 |
# Describes a foreign-key relationship between a parent (referenced) column
# and a child (referencing) column.
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
    'parentcolumn', 'childschema', 'childtable', 'childcolumn'])
# A table name together with an iterable of its ColumnMetadata records.
TableMetadata = namedtuple('TableMetadata', 'name columns')
24 |
25 |
def parse_defaults(defaults_string):
    """
    Yield the comma-separated default expressions contained in
    *defaults_string*, treating single- and double-quoted sections as
    opaque (commas inside quotes do not split).
    """
    if not defaults_string:
        return
    current = ''
    in_quote = None
    for char in defaults_string:
        if char == ' ' and not current:
            # Skip space after comma separating default expressions
            continue
        if char in '"\'':
            if in_quote is None:
                # Begin quoted section
                in_quote = char
            elif char == in_quote:
                # End quoted section
                in_quote = None
        elif char == ',' and in_quote is None:
            # End of one expression
            yield current
            current = ''
            continue
        current += char
    yield current
52 |
53 |
class FunctionMetadata:
    """Describes a SQL function's signature for the completer: argument
    names/types/modes, return type, defaults, and aggregate/window flags.
    Originally written for PostgreSQL functions."""

    def __init__(
        self, schema_name, func_name, arg_names, arg_types, arg_modes,
        return_type, is_aggregate, is_window, is_set_returning, arg_defaults
    ):
        """Class for describing a postgresql function.

        arg_modes (when given) follow the PostgreSQL convention used by
        args()/fields(): 'i' IN, 'o' OUT, 'b' INOUT, 'v' VARIADIC,
        't' TABLE. arg_defaults is a comma-separated string parsed with
        parse_defaults().
        """

        self.schema_name = schema_name
        self.func_name = func_name

        # Normalize to tuples so instances are hashable; None when absent.
        self.arg_modes = tuple(arg_modes) if arg_modes else None
        self.arg_names = tuple(arg_names) if arg_names else None

        # Be flexible in not requiring arg_types -- use None as a placeholder
        # for each arg. (Used for compatibility with old versions of postgresql
        # where such info is hard to get.
        if arg_types:
            self.arg_types = tuple(arg_types)
        elif arg_modes:
            self.arg_types = tuple([None] * len(arg_modes))
        elif arg_names:
            self.arg_types = tuple([None] * len(arg_names))
        else:
            self.arg_types = None

        self.arg_defaults = tuple(parse_defaults(arg_defaults))

        self.return_type = return_type.strip()
        self.is_aggregate = is_aggregate
        self.is_window = is_window
        self.is_set_returning = is_set_returning

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def _signature(self):
        # Hashable summary of all identity-relevant attributes; keep in
        # sync with __init__ (and therefore with __eq__).
        return (
            self.schema_name, self.func_name, self.arg_names, self.arg_types,
            self.arg_modes, self.return_type, self.is_aggregate,
            self.is_window, self.is_set_returning, self.arg_defaults
        )

    def __hash__(self):
        return hash(self._signature())

    def __repr__(self):
        return (
            (
                '%s(schema_name=%r, func_name=%r, arg_names=%r, '
                'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, '
                'is_window=%r, is_set_returning=%r, arg_defaults=%r)'
            ) % (self.__class__.__name__, self.schema_name, self.func_name, self.arg_names,
                 self.arg_types, self.arg_modes, self.return_type, self.is_aggregate,
                 self.is_window, self.is_set_returning, self.arg_defaults)
        )

    def has_variadic(self):
        # 'v' marks a VARIADIC parameter in arg_modes.
        return self.arg_modes and any(
            arg_mode == 'v' for arg_mode in self.arg_modes)

    def args(self):
        """Returns a list of input-parameter ColumnMetadata namedtuples."""
        if not self.arg_names:
            return []
        # Without explicit modes, every argument is treated as IN.
        modes = self.arg_modes or ['i'] * len(self.arg_names)
        args = [
            (name, typ)
            for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
            if mode in ('i', 'b', 'v')  # IN, INOUT, VARIADIC
        ]

        def arg(name, typ, num):
            # Defaults attach to the trailing parameters: argument `num`
            # has a default iff it falls within the last num_defaults args.
            num_args = len(args)
            num_defaults = len(self.arg_defaults)
            has_default = num + num_defaults >= num_args
            default = (
                self.arg_defaults[num - num_args + num_defaults] if has_default
                else None
            )
            return ColumnMetadata(name, typ, [], default, has_default)

        return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]

    def fields(self):
        """Returns a list of output-field ColumnMetadata namedtuples"""

        if self.return_type.lower() == 'void':
            return []
        if not self.arg_modes:
            # For functions without output parameters, the function name
            # is used as the name of the output column.
            # E.g. 'SELECT unnest FROM unnest(...);'
            return [ColumnMetadata(self.func_name, self.return_type, [])]

        return [ColumnMetadata(name, typ, [])
                for name, typ, mode in zip(
                    self.arg_names, self.arg_types, self.arg_modes)
                if mode in ('o', 'b', 't')]  # OUT, INOUT, TABLE
156 |
157 |
--------------------------------------------------------------------------------
/odbcli/completion/parseutils/tables.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from collections import namedtuple
3 | import sqlparse
4 | from sqlparse.sql import IdentifierList, Identifier, Function
5 | from sqlparse.tokens import Keyword, DML, Punctuation, Whitespace
6 |
# A parsed reference to a table (or table-valued function) appearing in a
# query: catalog.schema.name plus any alias.
TableReference = namedtuple('TableReference', ['catalog', 'schema', 'name', 'alias',
    'is_function'])
# .ref is the qualifier a column may carry: the alias when present, else the
# table name, double-quoted unless it is lower-case or already quoted.
TableReference.ref = property(lambda self: self.alias or (
    self.name if self.name.islower() or self.name[0] == '"'
    else '"' + self.name + '"'))
12 |
13 |
14 | # This code is borrowed from sqlparse example script.
15 | #
def is_subselect(parsed):
    """True when *parsed* is a group token that contains a DML statement,
    i.e. a parenthesized sub-select (or sub-insert/update/...)."""
    if not parsed.is_group:
        return False
    dml_values = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
    return any(
        item.ttype is DML and item.value.upper() in dml_values
        for item in parsed.tokens
    )
24 |
25 |
26 | def _identifier_is_function(identifier):
27 | return any(isinstance(t, Function) for t in identifier.tokens)
28 |
29 |
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens that form the table-reference part(s) of *parsed* —
    everything following FROM/INTO/UPDATE/TABLE/COPY or a ...JOIN keyword —
    recursing into sub-selects.
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                # An incomplete nested select won't be recognized correctly as a
                # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
                # the second FROM to trigger this elif condition resulting in a
                # StopIteration. So we need to ignore the keyword if the keyword
                # FROM.
                # Also 'SELECT * FROM abc JOIN def' will trigger this elif
                # condition. So we need to ignore the keyword JOIN and its variants
                # INNER JOIN, FULL OUTER JOIN, etc.
                return
            elif item.ttype is Keyword and (
                    not item.value.upper() == 'FROM') and \
                    (not item.value.upper().endswith('JOIN')):
                # Any other keyword ends the current table-reference section.
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
                    item_val.endswith('JOIN')):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if (identifier.ttype is Keyword and
                        identifier.value.upper() == 'FROM'):
                    tbl_prefix_seen = True
                    break
66 |
67 |
def extract_table_identifiers(token_stream, allow_functions=True):
    """Yields TableReference namedtuples for the identifiers in
    *token_stream* (as produced by extract_from_part)."""

    # We need to do some massaging of the names because postgres is case-
    # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
    def parse_identifier(item):
        # Returns (catalog, schema, name, alias) for one identifier token.
        # NOTE(review): relies on sqlparse's private Identifier._get_first_name;
        # verify behavior when upgrading sqlparse.
        alias = item.get_alias()
        sp_idx = item.token_next_by(t = Whitespace)[0] or len(item.tokens)
        # Reverse the dotted tokens (up to the first whitespace) so the
        # object name comes first: name[, '.', schema[, '.', catalog]].
        item_rev = Identifier(list(reversed(item.tokens[:sp_idx])))
        name = item_rev._get_first_name(real_name = True)
        alias = alias or name
        dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'))
        if dot_idx is not None:
            schema_name = item_rev._get_first_name(dot_idx, real_name = True)
            dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'), idx = dot_idx)
            if dot_idx is not None:
                catalog_name = item_rev._get_first_name(dot_idx, real_name = True)
            else:
                catalog_name = None
        else:
            schema_name = None
            catalog_name = None
        # TODO: this business below needs help
        # for one we need to apply this logic to catalog_name
        # then the logic around name_quoted = quote_count > 2 obviously
        # doesn't work. Finally, quotechar needs to be customized
        schema_quoted = schema_name and item.value[0] == '"'
        if schema_name and not schema_quoted:
            schema_name = schema_name.lower()
        quote_count = item.value.count('"')
        name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
        alias_quoted = alias and item.value[-1] == '"'
        if alias_quoted or name_quoted and not alias and name.islower():
            alias = '"' + (alias or name) + '"'
        if name and not name_quoted and not name.islower():
            if not alias:
                alias = name
            name = name.lower()
        return catalog_name, schema_name, name, alias

    for item in token_stream:
        if isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                # Sometimes Keywords (such as FROM ) are classified as
                # identifiers which don't have the get_real_name() method.
                try:
                    catalog_name, schema_name, real_name, alias = parse_identifier(identifier)
                    is_function = allow_functions and _identifier_is_function(identifier)
                except AttributeError:
                    continue
                if real_name:
                    yield TableReference(catalog_name, schema_name, real_name,
                            identifier.get_alias(), is_function)
        elif isinstance(item, Identifier):
            catalog_name, schema_name, real_name, alias = parse_identifier(item)
            is_function = allow_functions and _identifier_is_function(item)

            yield TableReference(catalog_name, schema_name, real_name, alias, is_function)
        elif isinstance(item, Function):
            # A bare function reference has no catalog/schema qualification.
            catalog_name, schema_name, real_name, alias = parse_identifier(item)
            yield TableReference(None, None, real_name, alias, allow_functions)
129 |
130 |
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
    """Extract the table names from an SQL statement.

    :param sql: string SQL statement (may be partial / mid-edit)
    :return: tuple of TableReference namedtuples (empty for no tables)
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return ()

    first_token = parsed[0].token_first()
    if first_token is None:
        # Whitespace-only input (e.g. "  ") parses to a statement with no
        # significant tokens; the previous code crashed on `.value` here.
        return ()

    # INSERT statements must stop looking for tables at the sign of first
    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
    # abc is the table name, but if we don't stop at the first lparen, then
    # we'll identify abc, col1 and col2 as table names.
    insert_stmt = first_token.value.lower() == 'insert'
    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)

    # Kludge: sqlparse mistakenly identifies insert statements as
    # function calls due to the parenthesized column list, e.g. interprets
    # "insert into foo (bar, baz)" as a function call to foo with arguments
    # (bar, baz). So don't allow any identifiers in insert statements
    # to have is_function=True
    identifiers = extract_table_identifiers(stream,
                                            allow_functions=not insert_stmt)
    # In the case 'sche.', we get an empty TableReference; remove that
    return tuple(i for i in identifiers if i.name)
158 |
159 |
--------------------------------------------------------------------------------
/odbcli/completion/parseutils/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import re
3 | import sqlparse
4 | from sqlparse.sql import Identifier
5 | from sqlparse.tokens import Token, Error
6 |
cleanup_regex = {
    # Alphanumerics and underscores only.
    'alphanum_underscore': re.compile(r'(\w+)$'),
    # Everything except spaces, parens, colon, and comma.
    'many_punctuations': re.compile(r'([^():,\s]+)$'),
    # Everything except spaces, parens, colon, comma, and period.
    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
    # Everything except whitespace.
    'all_punctuations': re.compile(r'([^\s]+)$'),
}


def last_word(text, include='alphanum_underscore'):
    r"""Return the trailing word of *text*, or '' when there is none.

    The *include* key selects which character class (from ``cleanup_regex``)
    counts as part of a word. Text ending in whitespace yields ''.

    >>> last_word('abc')
    'abc'
    >>> last_word('abc def')
    'def'
    >>> last_word('abc ')
    ''
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    """
    if not text or text[-1].isspace():
        return ''
    match = cleanup_regex[include].search(text)
    return match.group(0) if match else ''
63 |
64 |
def find_prev_keyword(sql, n_skip=0):
    """Find the last SQL keyword in an SQL statement.

    Returns (keyword_token, text) where *text* is the query with everything
    after that keyword stripped, or (None, '') when no keyword is found.
    An opening paren counts as a boundary; logical operators do not.
    """
    if not sql.strip():
        return None, ''

    tokens = list(sqlparse.parse(sql)[0].flatten())
    if n_skip:
        tokens = tokens[:-n_skip]

    logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')

    # Walk backwards by index. Flattened tokens may be children of nested
    # TokenLists, so TokenList.token_index() would raise ValueError for
    # them; scanning our own flat list sidesteps that entirely.
    for idx in range(len(tokens) - 1, -1, -1):
        tok = tokens[idx]
        if tok.value == '(' or (tok.is_keyword and
                                tok.value.upper() not in logical_operators):
            # Everything up to and including the keyword token becomes the
            # reduced query text.
            text = ''.join(t.value for t in tokens[:idx + 1])
            return tok, text

    return None, ''
100 |
101 |
# Postgresql dollar quote signs look like `$$` or `$tag$`: an opening `$`,
# an optional tag containing no `$`, then a closing `$`.
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
104 |
105 |
def is_open_quote(sql):
    """Return True when *sql* contains an unterminated quote."""
    # The input may hold several semicolon-separated statements; an open
    # quote in any one of them counts.
    for statement in sqlparse.parse(sql):
        if _parsed_is_open_quote(statement):
            return True
    return False
112 |
113 |
def _parsed_is_open_quote(parsed):
    """True when the parsed statement holds an unmatched single/double/dollar
    quote (sqlparse marks such tokens with the Error ttype)."""
    for tok in parsed.flatten():
        if tok.match(Token.Error, ("'", '"', "$")):
            return True
    return False
117 |
118 |
def parse_partial_identifier(word):
    """Attempt to parse a (partially typed) word as an identifier.

    *word* may carry a schema qualification, like `schema_name.partial_name`
    or `schema_name.`, and may contain an unterminated double quote, like
    `"schema` or `schema."partial_name`.

    :param word: string representing a (partially complete) identifier
    :return: sqlparse.sql.Identifier, or None
    """
    parsed = sqlparse.parse(word)[0]
    tokens = parsed.tokens
    if len(tokens) == 1 and isinstance(tokens[0], Identifier):
        return tokens[0]
    if parsed.token_next_by(m=(Error, '"'))[1]:
        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'.
        # Append the closing quote and reparse.
        return parse_partial_identifier(word + '"')
    return None
139 |
140 |
--------------------------------------------------------------------------------
/odbcli/completion/prioritization.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | import re
4 | from collections import defaultdict
5 | import sqlparse
6 | from sqlparse.tokens import Name
7 | from .mssqlliterals.main import get_literals
8 |
9 |
10 | white_space_regex = re.compile('\\s+', re.MULTILINE)
11 |
12 |
13 | def _compile_regex(keyword):
14 | # Surround the keyword with word boundaries and replace interior whitespace
15 | # with whitespace wildcards
16 | pattern = '\\b' + white_space_regex.sub(r'\\s+', keyword) + '\\b'
17 | return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
18 |
19 |
# Dialect keywords loaded from the bundled JSON literals file.
keywords = get_literals('keywords')
# Pre-compiled whole-word regex per keyword, used for prevalence counting.
keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)
22 |
23 |
class PrevalenceCounter:
    """Tracks how often SQL keywords and identifier names have been seen,
    so the completer can rank frequent items higher."""

    def __init__(self):
        self.keyword_counts = defaultdict(int)
        self.name_counts = defaultdict(int)

    def update(self, text):
        """Record both keyword and name occurrences found in *text*."""
        self.update_keywords(text)
        self.update_names(text)

    def update_names(self, text):
        """Count every Name token appearing in *text*."""
        for statement in sqlparse.parse(text):
            for tok in statement.flatten():
                if tok.ttype in Name:
                    self.name_counts[tok.value] += 1

    def clear_names(self):
        """Forget all identifier-name counts."""
        self.name_counts = defaultdict(int)

    def update_keywords(self, text):
        # Count keywords with our own regexes; sqlparse keyword detection
        # is database agnostic so it can't be relied on here.
        for keyword, regex in keyword_regexs.items():
            occurrences = sum(1 for _ in regex.finditer(text))
            if occurrences:
                self.keyword_counts[keyword] += occurrences

    def keyword_count(self, keyword):
        """Number of times *keyword* has been recorded (0 if never)."""
        return self.keyword_counts[keyword]

    def name_count(self, name):
        """Number of times identifier *name* has been recorded (0 if never)."""
        return self.name_counts[name]
54 |
--------------------------------------------------------------------------------
/odbcli/completion/sqlcompletion.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=too-many-instance-attributes
2 |
3 | from __future__ import print_function, unicode_literals
4 | from collections import namedtuple
5 | import re
6 | import sqlparse
7 | from sqlparse.sql import Comparison, Identifier, Where
8 | from sqlparse.tokens import Keyword, DML, Punctuation
9 | from .parseutils.utils import (
10 | last_word, find_prev_keyword, parse_partial_identifier)
11 | from .parseutils.tables import extract_tables
12 | from .parseutils.ctes import isolate_query_ctes
13 |
# Compatibility shim: Python 2's `basestring` covers both str and unicode;
# on Python 3 the NameError falls through and plain `str` is used.
try:
    string_types = basestring  # Python 2
except NameError:
    string_types = str  # Python 3
18 |
19 |
# Suggestion types: each namedtuple below represents one category of
# completion the completer can offer, plus the context needed to resolve it.
Blank = namedtuple('Blank', [])
Special = namedtuple('Special', [])
NamedQuery = namedtuple('NamedQuery', [])
Database = namedtuple('Database', [])
Schema = namedtuple('Schema', 'parent quoted')
Schema.__new__.__defaults__ = (None, False)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple('FromClauseItem', 'grandparent parent table_refs local_tables')
Table = namedtuple('Table', ['catalog', 'schema', 'table_refs', 'local_tables'])
View = namedtuple('View', ['catalog', 'schema', 'table_refs'])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple('Join', ['table_refs', 'schema', 'catalog'])

Function = namedtuple('Function', ['catalog', 'schema', 'table_refs', 'usage'])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, None, tuple(), None)
Table.__new__.__defaults__ = (None, None, tuple(), tuple())
View.__new__.__defaults__ = (None, None, tuple())
FromClauseItem.__new__.__defaults__ = (None, None, tuple(), tuple())

Column = namedtuple(
    'Column',
    ['table_refs', 'require_last_table', 'local_tables', 'qualifiable', 'context']
)
Column.__new__.__defaults__ = (None, None, tuple(), False, None)

# NOTE: this deliberately shadows `Keyword` imported from sqlparse.tokens
# above -- within this module `Keyword` means the suggestion type.
Keyword = namedtuple('Keyword', ['last_token'])
Keyword.__new__.__defaults__ = (None,)
Datatype = namedtuple('Datatype', ['schema'])
Alias = namedtuple('Alias', ['aliases'])

Path = namedtuple('Path', [])
56 |
57 |
class SqlStatement:
    """Wraps a (possibly partial) SQL statement plus cursor position and
    exposes the parsed pieces the suggestion engine needs."""

    def __init__(self, full_text, text_before_cursor):
        # Identifier currently being typed (filled in below when a partial
        # word precedes the cursor), e.g. `sche."tab` as an sqlparse
        # Identifier; None otherwise.
        self.identifier = None
        self.word_before_cursor = word_before_cursor = last_word(
            text_before_cursor, include='many_punctuations')
        full_text = _strip_named_query(full_text)
        text_before_cursor = _strip_named_query(text_before_cursor)

        # Pull out CTEs so a WITH clause's names are available as
        # statement-local tables.
        full_text, text_before_cursor, self.local_tables = \
            isolate_query_ctes(full_text, text_before_cursor)

        self.text_before_cursor_including_last_word = text_before_cursor

        # If we've partially typed a word then word_before_cursor won't be an
        # empty string. In that case we want to remove the partially typed
        # string before sending it to the sqlparser. Otherwise the last token
        # will always be the partially typed string which renders the smart
        # completion useless because it will always return the list of
        # keywords as completion.
        if self.word_before_cursor:
            if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
                parsed = sqlparse.parse(text_before_cursor)
            else:
                text_before_cursor = text_before_cursor[:-
                    len(word_before_cursor)]
                parsed = sqlparse.parse(text_before_cursor)
                self.identifier = parse_partial_identifier(word_before_cursor)
        else:
            parsed = sqlparse.parse(text_before_cursor)

        full_text, text_before_cursor, parsed = \
            _split_multiple_statements(full_text, text_before_cursor, parsed)

        self.full_text = full_text
        self.text_before_cursor = text_before_cursor
        self.parsed = parsed

        # Last significant token of the statement, or '' when there is none.
        self.last_token = parsed.token_prev(len(parsed.tokens))[1] \
            if parsed and parsed.token_prev(len(parsed.tokens))[1] else ''

    def is_insert(self):
        """True when the statement's first token is INSERT."""
        return self.parsed.token_first().value.lower() == 'insert'

    def get_tables(self, scope='full'):
        """ Gets the tables available in the statement.
        param `scope:` possible values: 'full', 'insert', 'before'
        If 'insert', only the first table is returned.
        If 'before', only tables before the cursor are returned.
        If not 'insert' and the stmt is an insert, the first table is skipped.
        """
        tables = extract_tables(
            self.full_text if scope == 'full' else self.text_before_cursor)
        if scope == 'insert':
            tables = tables[:1]
        elif self.is_insert():
            tables = tables[1:]
        return tables

    def get_previous_token(self, token):
        """Token immediately preceding *token*; raises ValueError when
        *token* is not a direct child of the parsed statement."""
        return self.parsed.token_prev(self.parsed.token_index(token))[1]

    def get_identifier_schema(self):
        """Schema qualification of the identifier being typed, or None.
        Unquoted schema names are lower-cased."""
        schema = self.identifier.get_parent_name() \
            if (self.identifier and self.identifier.get_parent_name()) else None
        # If schema name is unquoted, lower-case it
        if schema and self.identifier.value[0] != '"':
            schema = schema.lower()

        return schema

    def get_identifier_parents(self):
        """Return (catalog, schema) qualifiers of the identifier being typed.

        The token list is reversed so the trailing (partial) name comes
        first and each '.'-separated qualifier is peeled off in turn.
        Either element may be None when absent.
        """
        if self.identifier is None:
            return None, None
        item_rev = Identifier(list(reversed(self.identifier.tokens)))
        name = item_rev._get_first_name(real_name = True)
        dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'))
        if dot_idx is not None:
            schema_name = item_rev._get_first_name(dot_idx, real_name = True)
            dot_idx, _ = item_rev.token_next_by(m=(Punctuation, '.'), idx = dot_idx)
            if dot_idx is not None:
                catalog_name = item_rev._get_first_name(dot_idx, real_name = True)
            else:
                catalog_name = None
        else:
            schema_name = None
            catalog_name = None
        return catalog_name, schema_name

    def reduce_to_prev_keyword(self, n_skip=0):
        """Trim text_before_cursor back to (and including) the previous
        keyword; returns that keyword token, or None."""
        prev_keyword, self.text_before_cursor = \
            find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
        return prev_keyword
150 |
151 |
def suggest_type(full_text, text_before_cursor):
    """Suggest the completion type and scope for a partially typed query.

    Takes the full text typed so far plus the portion before the cursor,
    and returns a tuple of suggestion namedtuples ('table', 'column', ...);
    a Column suggestion's scope is a list of tables.
    """
    if full_text.startswith('\\i '):
        return (Path(),)

    # Temporary workaround: sqlparse can raise on some partial input; drop
    # the exception handling once that is fixed upstream.
    try:
        stmt = SqlStatement(full_text, text_before_cursor)
    except (TypeError, AttributeError):
        return []

    return suggest_based_on_last_token(stmt.last_token, stmt)
180 |
181 |
182 | named_query_regex = re.compile(r'^\s*\\sn\s+[A-z0-9\-_]+\s+')
183 |
184 |
185 | def _strip_named_query(txt):
186 | """
187 | This will strip "save named query" command in the beginning of the line:
188 | '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
189 | ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
190 | """
191 |
192 | if named_query_regex.match(txt):
193 | txt = named_query_regex.sub('', txt)
194 | return txt
195 |
196 |
197 | function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M)
198 |
199 |
200 | def _find_function_body(text):
201 | split = function_body_pattern.search(text)
202 | return (split.start(2), split.end(2)) if split else (None, None)
203 |
204 |
def _statement_from_function(full_text, text_before_cursor, statement):
    """When the cursor sits inside a dollar-quoted function body, narrow the
    texts to that body and re-split; otherwise return the inputs unchanged."""
    cursor_pos = len(text_before_cursor)
    body_start, body_end = _find_function_body(full_text)
    inside_body = (body_start is not None
                   and body_start <= cursor_pos < body_end)
    if not inside_body:
        return full_text, text_before_cursor, statement
    full_text = full_text[body_start:body_end]
    text_before_cursor = text_before_cursor[body_start:]
    parsed = sqlparse.parse(text_before_cursor)
    return _split_multiple_statements(full_text, text_before_cursor, parsed)
216 |
217 |
def _split_multiple_statements(full_text, text_before_cursor, parsed):
    """Isolate the statement under the cursor from a multi-statement buffer.

    Returns (full_text, text_before_cursor, statement) narrowed to the
    single sqlparse statement containing the cursor; statement is None for
    empty input. CREATE [OR REPLACE] FUNCTION statements are further
    narrowed to the function body when the cursor is inside it.
    """
    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds
        # the current position
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(str(statement))
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                text_before_cursor = full_text[stmt_start:current_pos]
                full_text = full_text[stmt_start:]
                break

    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        return full_text, text_before_cursor, None

    token2 = None
    if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
        token1 = statement.token_first()
        if token1:
            token1_idx = statement.token_index(token1)
            token2 = statement.token_next(token1_idx)[1]
    if token2 and token2.value.upper() == 'FUNCTION':
        full_text, text_before_cursor, statement = _statement_from_function(
            full_text, text_before_cursor, statement
        )
    return full_text, text_before_cursor, statement
253 |
254 |
# Maps a backslash-command suffix (e.g. \lt) to the suggestion type its
# argument should complete to.
SPECIALS_SUGGESTION = {
    'lf': Function,
    'lt': Table,
    'lv': View,
    'sf': Function,
}
261 |
262 |
263 | #def suggest_special(text):
264 | # text = text.lstrip()
265 | # cmd, _, arg = parse_special_command(text)
266 | #
267 | # if cmd == text:
268 | # # Trying to complete the special command itself
269 | # return (Special(),)
270 | #
271 | # if cmd in ('\\c', '\\connect'):
272 | # return (Database(),)
273 | #
274 | # if cmd == '\\ls':
275 | # return (Schema(),)
276 | #
277 | # if arg:
278 | # # Try to distinguish "\d name" from "\d schema.name"
279 | # # Note that this will fail to obtain a schema name if wildcards are
280 | # # used, e.g. "\d schema???.name"
281 | # parsed = sqlparse.parse(arg)[0].tokens[0]
282 | # try:
283 | # schema = parsed.get_parent_name()
284 | # except AttributeError:
285 | # schema = None
286 | # else:
287 | # schema = None
288 | #
289 | # if cmd[1:] == 'd':
290 | # # \d can describe tables or views
291 | # if schema:
292 | # return (Table(schema=schema),
293 | # View(schema=schema),)
294 | # return (Schema(),
295 | # Table(schema=None),
296 | # View(schema=None),)
297 | # if cmd[1:] in SPECIALS_SUGGESTION:
298 | # rel_type = SPECIALS_SUGGESTION[cmd[1:]]
299 | # if schema:
300 | # if rel_type == Function:
301 | # return (Function(schema=schema, usage='special'),)
302 | # return (rel_type(schema=schema),)
303 | # if rel_type == Function:
304 | # return (Schema(), Function(schema=None, usage='special'),)
305 | # return (Schema(), rel_type(schema=None))
306 | #
307 | # if cmd in ['\\n', '\\sn', '\\dn']:
308 | # return (NamedQuery(),)
309 | #
310 | # return (Keyword(), Special())
311 |
312 |
def suggest_based_on_last_token(token, stmt):
    """Map the most recent significant token to a tuple of suggestions.

    :param token: a plain string keyword (e.g. 'select'), or an sqlparse
        token / token group; the function recurses backwards through the
        statement until it finds a token it knows how to handle
    :param stmt: the SqlStatement being completed
    :return: tuple of suggestion namedtuples (Column, Table, Keyword, ...)
    """
    if isinstance(token, string_types):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly
        prev_keyword = stmt.reduce_to_prev_keyword()
        return suggest_based_on_last_token(prev_keyword, stmt)
    elif isinstance(token, Identifier):
        # If the previous token is an identifier, we can suggest datatypes if
        # we're in a parenthesized column/field list, e.g.:
        # CREATE TABLE foo (Identifier
        # CREATE FUNCTION foo (Identifier
        # If we're not in a parenthesized list, the most likely scenario is the
        # user is about to specify an alias, e.g.:
        # SELECT Identifier
        # SELECT foo FROM Identifier
        prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
        if prev_keyword and prev_keyword.value == '(':
            # Suggest datatypes
            return suggest_based_on_last_token('type', stmt)
        return (Keyword(),)
    else:
        token_v = token.value.lower()

    if not token:
        return (Keyword(), Special())
    if token_v.endswith('('):
        p = sqlparse.parse(stmt.text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            #  1 - Parenthesized clause like "WHERE foo AND ("
            #      Suggest columns/functions
            #  2 - Function call like "WHERE foo("
            #      Suggest columns/functions
            #  3 - Subquery expression like "WHERE EXISTS ("
            #      Suggest keywords, in order to do a subquery
            #  4 - Subquery OR array comparison like "WHERE foo = ANY("
            #      Suggest columns/functions AND keywords. (If we wanted to be
            #      really fancy, we could suggest only array-typed columns)

            column_suggestions = suggest_based_on_last_token('where', stmt)

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            prev_tok = where.token_prev(len(where.tokens) - 1)[1]

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            prev_tok = prev_tok.value.lower()
            if prev_tok == 'exists':
                return (Keyword(),)
            return column_suggestions

        # Get the token before the parens
        prev_tok = p.token_prev(len(p.tokens) - 1)[1]

        if (prev_tok and prev_tok.value and prev_tok.value.lower().split(' ')[-1] == 'using'):
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = stmt.get_tables('before')

            # suggest columns that are present in more than one table
            return (Column(table_refs=tables,
                           require_last_table=True,
                           local_tables=stmt.local_tables),)

        if p.token_first().value.lower() == 'select':
            # If the lparen is preceeded by a space chances are we're about to
            # do a sub-select.
            if last_word(stmt.text_before_cursor,
                         'all_punctuations').startswith('('):
                return (Keyword(),)
        prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
        if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
            return (
                Column(table_refs=stmt.get_tables('insert'), context='insert'),
            )
        # We're probably in a function argument list
        return (Column(table_refs=extract_tables(stmt.full_text),
                       local_tables=stmt.local_tables, qualifiable=True),)
    if token_v == 'set':
        return (Column(table_refs=stmt.get_tables(),
                       local_tables=stmt.local_tables),)
    if token_v in ('select', 'where', 'having', 'order by', 'distinct'):
        # Check for a table alias or schema qualification
        parent = stmt.identifier.get_parent_name() \
            if (stmt.identifier and stmt.identifier.get_parent_name()) else []
        tables = stmt.get_tables()
        if parent:
            tables = tuple(t for t in tables if identifies(parent, t))
            return (Column(table_refs=tables, local_tables=stmt.local_tables),
                    Table(schema=parent),
                    View(schema=parent),
                    Function(schema=parent),)
        return (Column(table_refs=tables, local_tables=stmt.local_tables,
                       qualifiable=True),
                Function(schema=None),
                Keyword(token_v.upper()),)
    if token_v == 'as':
        # Don't suggest anything for aliases
        return ()
    if (token_v.endswith('join') and token.is_keyword) or \
            token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate'):

        catalog, schema = stmt.get_identifier_parents()
        tables = extract_tables(stmt.text_before_cursor)
        is_join = token_v.endswith('join') and token.is_keyword

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = []

        if catalog is None and schema is None:
            suggest.insert(0, Database())
            suggest.append(Schema())
        elif not catalog:
            suggest.insert(0, Schema(parent = schema))
        if token_v == 'from' or is_join:
            suggest.append(FromClauseItem(grandparent = catalog,
                                          parent = schema,
                                          table_refs=tables,
                                          local_tables=stmt.local_tables))
        elif token_v == 'truncate':
            suggest.append(Table(catalog, schema))
        else:
            suggest.extend((Table(catalog, schema), View(catalog, schema)))

        # TODO: Join(catalog = catalog, ...)
        if is_join and _allow_join(stmt.parsed):
            tables = stmt.get_tables('before')
            suggest.append(Join(
                table_refs = tables,
                schema=schema,
                catalog = catalog))

        return tuple(suggest)

    # TODO: Use get_identifier_parents
    if token_v == 'function':
        schema = stmt.get_identifier_schema()
        # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions
        # WHERE function:`
        try:
            prev = stmt.get_previous_token(token).value.lower()
            if prev in ('drop', 'alter', 'create', 'create or replace'):
                return (Function(schema=schema, usage='signature'),)
        except ValueError:
            pass
        return tuple()

    if token_v in ('table', 'view'):
        # E.g. 'ALTER TABLE '
        rel_type = {
            'table': Table,
            'view': View,
            'function': Function}[token_v]
        catalog, schema = stmt.get_identifier_parents()
        suggest = []
        if catalog is None and schema is None:
            suggest.insert(0, Database())
            suggest.append(Schema())
        elif not catalog:
            suggest.insert(0, Schema(parent = schema))
        suggest.append(rel_type(catalog = catalog, schema = schema))

        # Return a tuple for consistency -- every other branch returns one.
        return tuple(suggest)

    if token_v == 'column':
        # E.g. 'ALTER TABLE foo ALTER COLUMN bar
        return (Column(table_refs=stmt.get_tables()),)

    if token_v == 'on':
        tables = stmt.get_tables('before')
        parent = stmt.identifier.get_parent_name() \
            if (stmt.identifier and stmt.identifier.get_parent_name()) else None
        if parent:
            # "ON parent."
            # parent can be either a schema name or table alias
            filteredtables = tuple(t for t in tables if identifies(parent, t))
            sugs = [Column(table_refs=filteredtables,
                           local_tables=stmt.local_tables),
                    Table(schema=parent),
                    View(schema=parent),
                    Function(schema=parent)]
            if filteredtables and _allow_join_condition(stmt.parsed):
                sugs.append(JoinCondition(table_refs=tables,
                                          parent=filteredtables[-1]))
            return tuple(sugs)
        # ON
        # Use table alias if there is one, otherwise the table name
        aliases = tuple(t.ref for t in tables)
        if _allow_join_condition(stmt.parsed):
            return (Alias(aliases=aliases), JoinCondition(
                table_refs=tables, parent=None))
        return (Alias(aliases=aliases),)

    if token_v in ('c', 'use', 'database', 'template'):
        # "\c ", "DROP DATABASE ",
        # "CREATE DATABASE WITH TEMPLATE "
        return (Database(),)
    if token_v == 'schema':
        # DROP SCHEMA schema_name, SET SCHEMA schema name
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
        quoted = prev_keyword and prev_keyword.value.lower() == 'set'
        # BUGFIX: pass `quoted` by keyword. Schema's first field is
        # `parent`, so the old positional call Schema(quoted) put the
        # boolean into `parent` and left `quoted` at its default.
        return (Schema(quoted=quoted),)
    if token_v.endswith(',') or token_v in ('=', 'and', 'or'):
        prev_keyword = stmt.reduce_to_prev_keyword()
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        return ()
    if token_v in ('type', '::'):
        # ALTER TABLE foo SET DATA TYPE bar
        # SELECT foo::bar
        # Note that tables are a form of composite type in postgresql, so
        # they're suggested here as well
        # TODO: Use get_identifier_parents
        schema = stmt.get_identifier_schema()
        suggestions = [Datatype(schema=schema),
                       Table(schema=schema)]
        if not schema:
            suggestions.append(Schema())
        return tuple(suggestions)
    if token_v in {'alter', 'create', 'drop'}:
        return (Keyword(token_v.upper()),)
    if token_v in {'limit'}:
        return (Blank(),)
    if token.is_keyword:
        # token is a keyword we haven't implemented any special handling for
        # go backwards in the query until we find one we do recognize
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        return (Keyword(token_v.upper()),)
    return (Keyword(),)
561 |
562 |
def identifies(string_id, ref):
    """Return truthy when *string_id* names TableReference *ref* -- by its
    alias, its bare name, or its schema-qualified name."""
    matches_alias = string_id == ref.alias
    matches_name = string_id == ref.name
    matches_qualified = ref.schema and (
        string_id == ref.schema + '.' + ref.name)
    return matches_alias or matches_name or matches_qualified
568 |
569 |
570 | def _allow_join_condition(statement):
571 | """
572 | Tests if a join condition should be suggested
573 |
574 | We need this to avoid bad suggestions when entering e.g.
575 | select * from tbl1 a join tbl2 b on a.id =
576 | So check that the preceding token is a ON, AND, or OR keyword, instead of
577 | e.g. an equals sign.
578 |
579 | :param statement: an sqlparse.sql.Statement
580 | :return: boolean
581 | """
582 |
583 | if not statement or not statement.tokens:
584 | return False
585 |
586 | last_tok = statement.token_prev(len(statement.tokens))[1]
587 | return last_tok.value.lower() in ('on', 'and', 'or')
588 |
589 |
590 | def _allow_join(statement):
591 | """
592 | Tests if a join should be suggested
593 |
594 | We need this to avoid bad suggestions when entering e.g.
595 | select * from tbl1 a join tbl2 b
596 | So check that the preceding token is a JOIN keyword
597 |
598 | :param statement: an sqlparse.sql.Statement
599 | :return: boolean
600 | """
601 |
602 | if not statement or not statement.tokens:
603 | return False
604 |
605 | last_tok = statement.token_prev(len(statement.tokens))[1]
606 | return (last_tok.value.lower().endswith('join') and
607 | last_tok.value.lower() not in('cross join', 'natural join'))
608 |
609 |
--------------------------------------------------------------------------------
/odbcli/config.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import shutil
3 | import os
4 | import platform
5 | from os.path import expanduser, exists, dirname
6 | from configobj import ConfigObj
7 |
8 |
def config_location():
    """Return the directory holding the odbcli config file.

    $XDG_CONFIG_HOME takes precedence when set; otherwise the conventional
    per-platform location is used.
    """
    xdg_home = os.environ.get("XDG_CONFIG_HOME")
    if xdg_home is not None:
        return "%s/odbcli/" % expanduser(xdg_home)
    if platform.system() == "Windows":
        return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\odbcli\\"
    return expanduser("~/.config/odbcli/")
16 |
17 |
def load_config(usr_cfg, def_cfg=None):
    """Load the user's config file layered over the packaged defaults.

    :param usr_cfg: path to the user's config (may contain '~')
    :param def_cfg: optional path to the default config
    :return: merged ConfigObj whose filename points at the user config
    """
    user_path = expanduser(usr_cfg)
    cfg = ConfigObj()
    # Defaults first, then user settings override them.
    cfg.merge(ConfigObj(def_cfg, interpolation=False))
    cfg.merge(ConfigObj(user_path, interpolation=False, encoding="utf-8"))
    cfg.filename = user_path
    return cfg
25 |
26 |
def ensure_dir_exists(path):
    """ Create the parent directory of *path* (user-expanded) if it does
    not already exist; existing directories are left untouched. """
    os.makedirs(expanduser(dirname(path)), exist_ok=True)
30 |
31 |
def write_default_config(source, destination, overwrite=False):
    """ Copy the packaged default configuration file to *destination*.

    Unless *overwrite* is set, an existing destination file is left
    untouched.  The destination's parent directory is created when
    missing.
    """
    destination = expanduser(destination)
    if exists(destination) and not overwrite:
        return

    ensure_dir_exists(destination)
    shutil.copyfile(source, destination)
40 |
41 |
def upgrade_config(config, def_config):
    """ Re-write the user's config file as the merge of the current
    defaults and the user's existing settings. """
    load_config(config, def_config).write()
45 |
46 |
def get_config(odbclirc_file=None):
    """ Locate, materialize (if missing) and load the odbcli config.

    When *odbclirc_file* is not given, the standard per-user location
    is used.  The packaged default config (odbcli/odbclirc) is copied
    there on first run and always serves as the defaults layer.
    """
    from odbcli import __file__ as package_root

    package_root = os.path.dirname(package_root)

    if not odbclirc_file:
        odbclirc_file = "%sconfig" % config_location()

    default_config = os.path.join(package_root, "odbclirc")
    write_default_config(default_config, odbclirc_file)

    return load_config(odbclirc_file, default_config)
58 |
--------------------------------------------------------------------------------
/odbcli/conn.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from cyanodbc import connect, Connection, SQLGetInfo, Cursor, DatabaseError, ConnectError
3 | from typing import Optional
4 | from cli_helpers.tabular_output import TabularOutputFormatter
5 | from logging import getLogger
6 | from re import sub
7 | from threading import Lock, Event, Thread, Condition
8 | from asyncio import get_event_loop
9 | from enum import IntEnum
10 | from .dbmetadata import DbMetadata
11 |
# Module-level formatter shared by sqlConnection.formatted_fetch() to
# render fetched rows as text tables.
formatter = TabularOutputFormatter()
13 |
# Coarse life-cycle state of a sqlConnection; used to gate metadata
# queries while a statement is executing or results are being fetched.
connStatus = Enum(
    "connStatus",
    [
        ("DISCONNECTED", 0),
        ("IDLE", 1),
        ("EXECUTING", 2),
        ("FETCHING", 3),
        ("ERROR", 4),
    ],
)
20 |
# Outcome of the most recent execute(); IntEnum so callers may compare
# against plain integers.
executionStatus = IntEnum(
    "executionStatus",
    [("OK", 0), ("FAIL", 1), ("OKWRESULTS", 2)],
)
25 |
class sqlConnection:
    """ Wraps a cyanodbc Connection for one DSN, adding:

    * thread-safe connection-status and execution-status tracking,
    * an asynchronous fetch pipeline (async_fetchall / fetch_from_cache),
    * lock-guarded catalog/schema/table/column metadata lookups used by
      auto-completion and the object browser, and
    * identifier-quoting and search-pattern-escaping helpers.
    """

    def __init__(
        self,
        dsn: str,
        conn: Optional[Connection] = None,
        username: Optional[str] = "",
        password: Optional[str] = ""
    ) -> None:
        """
        :param dsn: ODBC data source name used to build the connection string.
        :param conn: an existing cyanodbc Connection to wrap; a fresh
            (disconnected) Connection is created when omitted.
        :param username: default UID used by connect().
        :param password: default PWD used by connect().
        """
        self.dsn = dsn
        # Avoid the mutable-default-argument trap: a `conn = Connection()`
        # default would be created once at class-definition time and shared
        # by every sqlConnection constructed without an explicit conn.
        self.conn = conn if conn is not None else Connection()
        self.cursor: Cursor = None
        self.query: str = None
        self.username = username
        self.password = password
        self.logger = getLogger(__name__)
        self.dbmetadata = DbMetadata()
        # Lazily populated from SQLGetInfo -- see the properties below.
        self._quotechar = None
        self._search_escapechar = None
        self._search_escapepattern = None
        # Lock to be held by database interaction that happens
        # in the main process. Recall, main-buffer as well as preview
        # buffer queries get executed in a separate process, however
        # auto-completion, as well as object browser expansion happen
        # in the main process possibly multi-threaded. Multi threaded is fine
        # we don't want the main process to lock-up while writing a query,
        # however, we don't want to potentially hammer the connection with
        # multiple auto-completion result queries before each has had a chance
        # to return.
        self._lock = Lock()

        # Lock to be held when updating self.status, which can happen from
        # a thread
        self._status_lock = Lock()
        # Lock to be held when updating self._execution_status_lock which
        # can happen from a thread
        self._execution_status_lock = Lock()

        # Lock that protects interaction with _fetch_res
        self._fetch_cv = Condition()
        # This is the list that carries the cache of retrieved rows via the
        # asynchronous fetch operation
        self._fetch_res: list = []
        self._fetch_thread = Thread()
        self._execution_thread = Thread()
        self._cancel_async_event = Event()

        self._status = connStatus.DISCONNECTED
        self._execution_status: executionStatus = executionStatus.OK
        self._execution_err: str = None

    @property
    def execution_status(self) -> executionStatus:
        """ Hold the lock here since it gets assigned in execute
        which can be called in a different thread """
        with self._execution_status_lock:
            res = self._execution_status
        return res

    @property
    def status(self) -> connStatus:
        """ Hold the lock here since it can be assigned in more than one
        thread """
        with self._status_lock:
            res = self._status
        return res

    @property
    def execution_err(self) -> str:
        """ Last execution error: Cleared prior to every execution.
        Hold the lock here since it gets assigned in execute
        which can be called in a different thread """
        with self._lock:
            res = self._execution_err
        return res

    @property
    def quotechar(self) -> str:
        """ Identifier quote character reported by the driver (cached). """
        if self._quotechar is None:
            self._quotechar = self.conn.get_info(
                SQLGetInfo.SQL_IDENTIFIER_QUOTE_CHAR)
            # pyodbc note
            # self._quotechar = self.conn.getinfo(
        return self._quotechar

    @property
    def search_escapechar(self) -> str:
        """ Escape character for LIKE-pattern metacharacters (cached). """
        if self._search_escapechar is None:
            self._search_escapechar = self.conn.get_info(
                SQLGetInfo.SQL_SEARCH_PATTERN_ESCAPE)
        return self._search_escapechar

    @property
    def search_escapepattern(self) -> str:
        """ The escape character in a form safe to embed in regex
        replacement strings (cached). """
        if self._search_escapepattern is None:
            # https://stackoverflow.com/questions/2428117/casting-raw-strings-python
            self._search_escapepattern = \
                (self.search_escapechar).encode("unicode-escape").decode()

        return self._search_escapepattern

    def update_status(self, status: connStatus = connStatus.IDLE) -> None:
        """ Thread safe way of updating the connection status
        """
        with self._status_lock:
            self._status = status

    def update_execution_status(self, status: executionStatus = executionStatus.OK) -> None:
        """ Thread safe way of updating the execution status
        """
        with self._execution_status_lock:
            self._execution_status = status

    def sanitize_search_string(self, term) -> str:
        """ Escape LIKE metacharacters (_ and %) in *term* using the
        driver's search escape character. """
        if term is not None and len(term):
            res = sub("(_|%)", self.search_escapepattern + "\\1", term)
        else:
            res = term
        return res

    def unsanitize_search_string(self, term) -> str:
        """ Best-effort inverse of sanitize_search_string: strip the
        escape characters again. """
        if term is not None and len(term):
            res = sub(self.search_escapepattern, "", term)
        else:
            res = term
        return res

    def escape_name(self, name):
        """ Quote an identifier with the driver's quote character. """
        if name:
            qtchar = self.quotechar
            name = (qtchar + "%s" + qtchar) % name
        return name

    def escape_names(self, names):
        """ Quote a list of identifiers. """
        return [self.escape_name(name) for name in names]

    def unescape_name(self, name):
        """ Unquote a string."""
        if name:
            qtchar = self.quotechar
            if name and name[0] == qtchar and name[-1] == qtchar:
                name = name[1:-1]

        return name

    def connect(
            self,
            username: str = "",
            password: str = "",
            force: bool = False) -> None:
        """ Open the ODBC connection (DSN + optional UID/PWD).

        Credentials passed here are remembered for later reconnects.
        No-op when already connected unless *force* is set.

        :raises ConnectError: when the driver refuses the connection.
        """
        uid = username or self.username
        pwd = password or self.password
        conn_str = "DSN=" + self.dsn + ";"
        if len(uid):
            self.username = uid
            conn_str = conn_str + "UID=" + uid + ";"
        if len(pwd):
            self.password = pwd
            conn_str = conn_str + "PWD=" + pwd + ";"
        if force or not self.conn.connected():
            try:
                self.conn = connect(dsn = conn_str, timeout = 5)
                self.update_status(connStatus.IDLE)
            except ConnectError as e:
                self.logger.error("Error while connecting: %s", str(e))
                raise ConnectError(e)

    def fetchmany(self, size) -> list:
        """ Fetch up to *size* rows from the active cursor, append them to
        the async fetch cache and notify any waiter.  Returns the rows. """
        with self._lock:
            if self.cursor:
                # This gets called in a thread / so exceptions can get lost
                # Make sure to recover after cyanodbc errors so that we can
                # complete the tail part of the worker (status/notify)
                try:
                    res = self.cursor.fetchmany(size)
                except DatabaseError as e:
                    self.logger.warning("Error while fetching: %s", str(e))
                    res = []
            else:
                res = []

        if len(res) < 1:
            # Result set depleted (or error): the connection is idle again.
            self.update_status(connStatus.IDLE)
        with self._fetch_cv:
            self._fetch_res.extend(res)
            self._fetch_cv.notify()

        return res

    def async_fetchall(self, size, app) -> None:
        """ True asynchronous fetch. Will start a fetch, that will fetch
        *all* results (in chunks of size = size) in a background operation
        until the result set is depleted or we signal a stop via
        _cancel_async_event. After thread operation is completed,
        it asks the running event loop to redraw the app to pick up the
        new connection status (IDLE).
        """
        self._cancel_async_event.clear()
        self.update_status(connStatus.FETCHING)
        loop = get_event_loop()
        def _run():
            while True:
                res = self.fetchmany(size)
                if len(res) < 1 or self._cancel_async_event.is_set():
                    self.update_status(connStatus.IDLE)
                    self._cancel_async_event.clear()
                    # Should we try close cursor here? Problem is that
                    # close_curor attempts to call cancel async which would
                    # block until this thread is over
                    break
            loop.call_soon_threadsafe(app.invalidate)
            return

        self._fetch_thread = Thread(target = _run, daemon = True)
        self._fetch_thread.start()
        return

    def fetch_from_cache(self, size, wait = False) -> list:
        """ Will grab the first size elements from self._fetch_res. Recall
        self._fetch_res is the result cache that is built up via an async
        fetch that grabs rows in chunks. Here, in a threadsafe manner we
        wait for a the asynchronous method to grab enough elements or
        finish the fetch operation altogether, then we 'pop' from the
        fetch result cache.
        """
        with self._fetch_cv:
            if wait:
                self._fetch_cv.wait_for(
                    lambda: len(self._fetch_res) > size or self.status == connStatus.IDLE)
            res = self._fetch_res[:size]
            del self._fetch_res[:size]
            return res


    def cancel_async_fetchall(self) -> None:
        """ Signal fetching thread to terminate operation, then wait / block
        until thread terminates. Also clear the fetch result cache.
        """
        self.logger.debug("cancel_async_fetchall ...")
        self._cancel_async_event.set()
        if self._fetch_thread.is_alive():
            self._fetch_thread.join()
        with self._fetch_cv:
            self._fetch_res = []

    def execute(self, query, parameters = None, event: Event = None) -> Cursor:
        """ Execute *query* on a fresh cursor, recording status and any
        DatabaseError in execution_status / execution_err.

        :param event: optional Event that is *always* set when the
            attempt finishes, so a waiter (async_execute) never blocks
            forever -- even if an unexpected exception propagates.
        :return: the cursor used for the execution.
        """
        self.logger.debug("Execute: %s", query)
        with self._lock:
            try:
                self._execution_err = None
                self.update_status(connStatus.EXECUTING)
                # Cursor creation is inside the try: if the connection has
                # dropped it raises DatabaseError, which previously escaped
                # this handler and could leave a waiting thread blocked.
                self.cursor = self.conn.cursor()
                self.cursor.execute(query, parameters)
                self.update_status(connStatus.IDLE)
                self.update_execution_status(executionStatus.OK)
                self.query = query
            except DatabaseError as e:
                self.update_status(connStatus.IDLE)
                self.update_execution_status(executionStatus.FAIL)
                self._execution_err = str(e)
                self.logger.warning("Execution error: %s", str(e))
            finally:
                # Unblock a waiting async_execute() no matter what happened.
                if event is not None:
                    event.set()
            return self.cursor

    def async_execute(self, query) -> Cursor:
        """ async_ is a misnomer here. It does execute in a new thread
        however it will also wait for execution to complete. At this time
        this helps us with registering KeyboardInterrupt during cyanodbc.
        execute only; it may evolve to have more true async-like behavior.
        """
        self.close_cursor()
        exec_event = Event()
        self._execution_thread = Thread(
                target = self.execute,
                kwargs = {"query": query, "parameters": None, "event": exec_event},
                daemon = True)
        self._execution_thread.start()
        # Will block but can be interrupted
        exec_event.wait()
        return self.cursor

    def list_catalogs(self) -> list:
        """ Enumerate catalogs via the ODBC API; empty list when busy or
        on error. """
        # pyodbc note
        # return conn.cursor().tables(catalog = "%").fetchall()
        res = []
        if self.status != connStatus.IDLE:
            return res
        try:
            if self.conn.connected():
                self.logger.debug("Calling list_catalogs...")
                with self._lock:
                    res = self.conn.list_catalogs()
                self.logger.debug("list_catalogs: done")
        except DatabaseError as e:
            self.update_status(connStatus.ERROR)
            self.logger.warning("list_catalogs: %s", str(e))

        return res

    def list_schemas(self, catalog = None) -> list:
        """ Enumerate schemas in the current catalog via the ODBC API;
        empty list when busy, on error, or for a foreign catalog. """
        res = []

        # We only trust this generic implementation if attempting to list
        # schemata in current catalog (or catalog argument is None)
        if catalog is not None and not catalog == self.current_catalog():
            return res

        if self.status != connStatus.IDLE:
            return res
        try:
            if self.conn.connected():
                self.logger.debug("Calling list_schemas...")
                with self._lock:
                    res = self.conn.list_schemas()
                self.logger.debug("list_schemas: done")
        except DatabaseError as e:
            self.update_status(connStatus.ERROR)
            self.logger.warning("list_schemas: %s", str(e))

        return res

    def find_tables(
        self,
        catalog = "",
        schema = "",
        table = "",
        type = "") -> list:
        """ SQLTables lookup; empty list when busy or on error. """
        res = []

        if self.status != connStatus.IDLE:
            return res
        try:
            if self.conn.connected():
                self.logger.debug("Calling find_tables: %s, %s, %s, %s",
                        catalog, schema, table, type)
                with self._lock:
                    res = self.conn.find_tables(
                            catalog = catalog,
                            schema = schema,
                            table = table,
                            type = type)
                self.logger.debug("find_tables: done")
        except DatabaseError as e:
            self.logger.warning("find_tables: %s.%s.%s, type %s: %s", catalog, schema, table, type, str(e))

        return res

    def find_columns(
        self,
        catalog = "",
        schema = "",
        table = "",
        column = "") -> list:
        """ SQLColumns lookup; empty list when busy or on error. """
        res = []
        if self.status != connStatus.IDLE:
            return res

        try:
            if self.conn.connected():
                self.logger.debug("Calling find_columns: %s, %s, %s, %s",
                        catalog, schema, table, column)
                with self._lock:
                    res = self.conn.find_columns(
                            catalog = catalog,
                            schema = schema,
                            table = table,
                            column = column)
                self.logger.debug("find_columns: done")
        except DatabaseError as e:
            self.logger.warning("find_columns: %s.%s.%s, column %s: %s", catalog, schema, table, column, str(e))

        return res

    def find_procedures(
        self,
        catalog = "",
        schema = "",
        procedure = "") -> list:
        """ SQLProcedures lookup; empty list when busy or on error. """
        res = []

        if self.status != connStatus.IDLE:
            return res

        try:
            if self.conn.connected():
                self.logger.debug("Calling find_procedures: %s, %s, %s",
                        catalog, schema, procedure)
                with self._lock:
                    res = self.conn.find_procedures(
                            catalog = catalog,
                            schema = schema,
                            procedure = procedure)
                self.logger.debug("find_procedures: done")
        except DatabaseError as e:
            self.logger.warning("find_procedures: %s.%s.%s: %s", catalog, schema, procedure, str(e))

        return res

    def find_procedure_columns(
        self,
        catalog = "",
        schema = "",
        procedure = "",
        column = "") -> list:
        """ SQLProcedureColumns lookup; empty list when busy or on error. """
        res = []

        if self.status != connStatus.IDLE:
            return res

        try:
            if self.conn.connected():
                self.logger.debug("Calling find_procedure_columns: %s, %s, %s, %s",
                        catalog, schema, procedure, column)
                with self._lock:
                    res = self.conn.find_procedure_columns(
                            catalog = catalog,
                            schema = schema,
                            procedure = procedure,
                            column = column)
                self.logger.debug("find_procedure_columns: done")
        except DatabaseError as e:
            self.logger.warning("find_procedure_columns: %s.%s.%s, column %s: %s", catalog, schema, procedure, column, str(e))

        return res

    def current_catalog(self) -> str:
        """ Name of the currently connected catalog, or None when
        disconnected. """
        if self.conn.connected():
            return self.conn.catalog_name
        return None

    def connected(self) -> bool:
        """ Whether the underlying ODBC connection is open. """
        return self.conn.connected()

    def catalog_support(self) -> bool:
        """ Whether the driver reports catalog-name support. """
        res = self.conn.get_info(SQLGetInfo.SQL_CATALOG_NAME)
        return res == True or res == 'Y'
        # pyodbc note
        # return self.conn.getinfo(pyodbc.SQL_CATALOG_NAME) == True or self.conn.getinfo(pyodbc.SQL_CATALOG_NAME) == 'Y'

    def get_info(self, code: int) -> str:
        """ Pass-through to the driver's SQLGetInfo. """
        return self.conn.get_info(code)

    def close(self) -> None:
        """ Close the underlying connection if open. """
        self.logger.debug("close ...")
        # TODO: When disconnecting
        # We likely don't want to allow any exception to
        # propagate. Catch DatabaseError?
        if self.conn.connected():
            self.conn.close()

    def close_cursor(self) -> None:
        """ Cancel any async fetch, close and discard the active cursor,
        and mark the connection idle. """
        self.logger.debug("Close cursor ...")
        self.cancel_async_fetchall()
        if self.cursor:
            with self._lock:
                self.cursor.close()
                self.cursor = None
        self.query = None
        self.update_status(connStatus.IDLE)

    def cancel(self) -> None:
        """ Cancel an in-flight execution / fetch and wait for the
        execution thread to wind down. """
        self.logger.debug("cancel ...")
        self.cancel_async_fetchall()
        if self.cursor:
            # Should not hold _lock here. Point here is to cancel execution
            # that might be taking place in a separate thread where the execution
            # lock is being held
            self.cursor.cancel()
        if self._execution_thread.is_alive():
            self._execution_thread.join()
        self.query = None
        self.update_status(connStatus.IDLE)

    def preview_query(
            self,
            name,
            obj_type = "table",
            filter_query = "",
            limit = -1) -> str:
        """ Currently we only have a generic implementation for tables and
        views. Otherwise (functions) we return None
        """
        if obj_type == "table" or obj_type == "view":
            qry = "SELECT * FROM " + name + " " + filter_query
            if limit > 0:
                qry = qry + " LIMIT " + str(limit)
        else:
            qry = None
        return qry

    def formatted_fetch(self, size, cols, format_name = "psql"):
        """ Generator yielding formatted text chunks of the cached result
        set until fetching finishes and the cache is drained. """
        while True:
            res = self.fetch_from_cache(size, wait = True)
            if len(res) < 1 and self.status != connStatus.FETCHING:
                break
            if len(res) > 0:
                yield "\n".join(
                        formatter.format_output(
                            res,
                            cols,
                            format_name = format_name))
527 |
# Registry mapping a DBMS product name (as reported by the driver) to the
# sqlConnection subclass that special-cases it; populated at module bottom.
connWrappers = {}
529 |
class MSSQL(sqlConnection):
    """ SQL Server / FreeTDS specific overrides of sqlConnection. """

    def find_tables(
        self,
        catalog = "",
        schema = "",
        table = "",
        type = "") -> list:
        """ FreeTDS does not allow us to query catalog == '', and
        schema = '' which, according to the ODBC spec for SQLTables should
        return tables outside of any catalog/schema. In the case of FreeTDS
        what gets passed to the sp_tables sproc is null, which in turn
        is interpreted as a wildcard. For the time being intercept
        these queries here (used in auto completion) and return empty
        set. """

        # "\x00" is the marker the completion layer uses for
        # "outside any catalog/schema" -- see docstring above.
        if catalog == "\x00" and schema == "\x00":
            return []

        return super().find_tables(
            catalog = catalog,
            schema = schema,
            table = table,
            type = type)

    def list_schemas(self, catalog = None) -> list:
        """ Optimization for listing out-of-database schemas by
        always querying catalog.sys.schemas. """
        res = []
        if self.status != connStatus.IDLE:
            return res

        # Built-in database roles are excluded: they show up in
        # sys.schemas but are not useful as completion targets.
        qry = "SELECT name FROM {catalog}.sys.schemas " \
            "WHERE name NOT IN ('db_owner', 'db_accessadmin', " \
            "'db_securityadmin', 'db_ddladmin', 'db_backupoperator', " \
            "'db_datareader', 'db_datawriter', 'db_denydatareader', " \
            "'db_denydatawriter')"

        if catalog is None and self.current_catalog():
            catalog_local = self.current_catalog()
        else:
            # We are going to be outright executing, versus
            # using the ODBC API.
            # let's make sure there is nothing escaped here
            catalog_local = self.unsanitize_search_string(catalog)

        if catalog_local:
            try:
                self.logger.debug("Calling list_schemas...")
                crsr = self.execute(qry.format(catalog = catalog_local))
                res = crsr.fetchall()
                crsr.close()
                self.logger.debug("Calling list_schemas: done")
                schemas = [r[0] for r in res]
                if len(schemas):
                    return schemas
            except DatabaseError as e:
                # execute has an exception handler, but the cursor calls may
                # throw
                self.close_cursor()
                self.logger.warning("MSSQL list_schemas: %s", str(e))

        # Fall back to the generic ODBC implementation when the direct
        # query produced nothing (or failed).
        return super().list_schemas(catalog = catalog)

    def preview_query(
        self,
        name,
        obj_type = "table",
        filter_query = "",
        limit = -1) -> str:
        """ Build a preview query using T-SQL dialect: SELECT TOP instead
        of LIMIT, and sys.sql_modules for function/procedure source. """

        if obj_type == "table" or obj_type == "view":
            qry = " * FROM " + name + " " + filter_query
            if limit > 0:
                qry = "SELECT TOP " + str(limit) + qry
            else:
                qry = "SELECT" + qry
        elif obj_type == "function":
            # Sproc names in SQLServer come back with
            # catalog.schema.name;INT with the trailing suffix
            # not useful
            name_sanitized = sub("(;\\d{0,})(\")$", "\\2", name)
            catalog_local = sub("(.*)\\.(.*)\\.(.*)", "\\1", name)
            qry = "SELECT definition FROM {catalog}.sys.sql_modules " \
                "WHERE object_id = (OBJECT_ID(N'{name}'))"
            qry = qry.format(catalog = catalog_local, name = name_sanitized)
        else:
            qry = None

        return qry
619 |
class PSSQL(sqlConnection):
    """ PostgreSQL-specific overrides.

    At least the psql odbc driver used here has an annoying habit of
    treating the catalog and schema fields interchangeably, which in
    turn screws up completion.  Every metadata lookup therefore first
    verifies that the requested catalog is the connected one; the
    check previously appeared copy-pasted in all four overrides and is
    now factored into _is_current_catalog().
    """

    def _is_current_catalog(self, catalog) -> bool:
        """ True when *catalog* names the currently connected catalog,
        in either raw or search-escaped form. """
        current = self.current_catalog()
        return catalog in [current, self.sanitize_search_string(current)]

    def find_tables(
        self,
        catalog = "",
        schema = "",
        table = "",
        type = "") -> list:
        """ SQLTables lookup, restricted to the connected catalog. """

        if not self._is_current_catalog(catalog):
            return []

        return super().find_tables(
            catalog = catalog,
            schema = schema,
            table = table,
            type = type)

    def find_procedures(
        self,
        catalog = "",
        schema = "",
        procedure = "") -> list:
        """ SQLProcedures lookup, restricted to the connected catalog. """

        if not self._is_current_catalog(catalog):
            return []

        return super().find_procedures(
            catalog = catalog,
            schema = schema,
            procedure = procedure)

    def find_columns(
        self,
        catalog = "",
        schema = "",
        table = "",
        column = "") -> list:
        """ SQLColumns lookup, restricted to the connected catalog. """

        if not self._is_current_catalog(catalog):
            return []

        return super().find_columns(
            catalog = catalog,
            schema = schema,
            table = table,
            column = column)

    def find_procedure_columns(
        self,
        catalog = "",
        schema = "",
        procedure = "",
        column = "") -> list:
        """ SQLProcedureColumns lookup, restricted to the connected
        catalog.  In addition wildcards in the column field seem to not
        work with this driver -- but an empty string does. """

        if not self._is_current_catalog(catalog):
            return []

        if column == "%":
            column = ""

        return super().find_procedure_columns(
            catalog = catalog,
            schema = schema,
            procedure = procedure,
            column = column)
698 |
class MySQL(sqlConnection):
    """ MySQL-specific overrides of sqlConnection. """

    def list_schemas(self, catalog = None) -> list:
        """ Only catalogs for MySQL, it seems,
        however, list_schemas returns [""] which
        causes blank entries to show up in auto
        completion. Also confuses some of the checks we have
        that look for len(list_schemas) < 1 to decide whether
        to fall-back to find_tables. Make sure that for MySQL
        we do, in-fact fall-back to find_tables"""
        return []

    def current_catalog(self) -> str:
        """ Current catalog name; the driver reports the literal string
        "null" for no catalog, which we map to "".  None when the
        connection is closed. """
        if not self.conn.connected():
            return None
        name = self.conn.catalog_name
        return "" if name == "null" else name
717 |
class Snowflake(sqlConnection):
    """ Snowflake-specific overrides of sqlConnection. """

    def find_tables(
        self,
        catalog = "",
        schema = "",
        table = "",
        type = "") -> list:
        """ Same as the generic SQLTables lookup, except that the table
        type is upper-cased first (Snowflake expects e.g. 'TABLE'). """
        return super().find_tables(
            catalog = catalog,
            schema = schema,
            table = table,
            type = type.upper())
733 |
# Register DBMS-specific wrappers; keys match the product name string the
# driver reports, used to pick a wrapper class at connection time.
connWrappers["MySQL"] = MySQL
connWrappers["Microsoft SQL Server"] = MSSQL
connWrappers["PostgreSQL"] = PSSQL
connWrappers["Snowflake"] = Snowflake
738 |
--------------------------------------------------------------------------------
/odbcli/dbmetadata.py:
--------------------------------------------------------------------------------
1 | from threading import Lock, Event, Thread
2 |
class DbMetadata():
    """ Thread-safe, case-insensitive cache of database objects used by
    auto-completion.

    Layout (one tree per object type 'table'/'view'/'function'/'datatype'):

        _dbmetadata[obj_type][catalog_lower] = (catalog_cased, schemas)
        schemas[schema_lower]                = (schema_cased, objects)
        objects[name_lower]                  = (name_cased, {})

    Keys are lower-cased for case-insensitive lookup; the first tuple
    element preserves the original casing for display.
    """

    def __init__(self) -> None:
        # One lock guards all reads and writes: the cache is populated and
        # queried from completion threads.
        self._lock = Lock()
        self._dbmetadata = {'table': {}, 'view': {}, 'function': {},
                'datatype': {}}

    def extend_catalogs(self, names: list) -> None:
        """ Register catalogs (in every object-type tree).  Note: an
        already-known catalog is reset to an empty schema map. """
        with self._lock:
            for metadata in self._dbmetadata.values():
                for catalog in names:
                    metadata[catalog.lower()] = (catalog, {})
        return

    def get_catalogs(self, obj_type: str = "table", cased: bool = True) -> list:
        """ Retrieve catalogs as the keys for _dbmetadata[obj_type]
        If no keys are found it returns None.

        :param cased: return original-cased names (True) or the
            lower-cased keys (False).
        """
        with self._lock:
            if cased:
                res = [casedkey for casedkey, mappedvalue in self._dbmetadata[obj_type].values()]
            else:
                res = list(self._dbmetadata[obj_type].keys())

        if len(res) == 0:
            return None

        return res

    def extend_schemas(self, catalog, names: list) -> None:
        """ This method will force/create [catalog] dictionary
        in the event that len(names) > 0, and overwrite if
        anything was there to begin with.
        """
        catlower = catalog.lower()
        cat_cased = catalog
        if len(names):
            with self._lock:
                for metadata in self._dbmetadata.values():
                    # Preserve casing if an entry already there
                    if catlower in metadata.keys() and len(metadata[catlower]):
                        cat_cased = metadata[catlower][0]
                    metadata[catlower] = (cat_cased, {})
                    for schema in names:
                        metadata[catlower][1][schema.lower()] = (schema, {})
        return

    def get_schemas(self, catalog: str, obj_type: str = "table", cased: bool = True) -> list:
        """ Retrieve schemas as the keys for _dbmetadata[obj_type][catalog]
        If catalog is not part of the _dbmetadata[obj_type] keys will return
        None.
        """

        catlower = catalog.lower()
        cats = self.get_catalogs(obj_type = obj_type, cased = False)
        if cats is None or catlower not in cats:
            return None

        with self._lock:
            if cased:
                res = [casedkey for casedkey, mappedvalue in self._dbmetadata[obj_type][catlower][1].values()]
            else:
                res = list(self._dbmetadata[obj_type][catlower][1].keys())

        return res

    def extend_objects(self, catalog, schema, names: list, obj_type: str) -> None:
        """ Register objects under catalog.schema in the obj_type tree,
        creating the catalog/schema entries (in all trees) as needed.

        When *names* is empty the schema entry is removed from the
        obj_type tree instead (a missing catalog/schema is a no-op --
        previously this raised KeyError via an unguarded ``del``).
        """
        catlower = catalog.lower()
        schlower = schema.lower()
        if len(names):
            with self._lock:
                for otype in self._dbmetadata.keys():
                    # Loop over tables, views, functions
                    if catlower not in self._dbmetadata[otype].keys():
                        self._dbmetadata[otype][catlower] = (catalog, {})
                    if schlower not in self._dbmetadata[otype][catlower][1].keys():
                        self._dbmetadata[otype][catlower][1][schlower] = (schema, {})
                for obj in names:
                    self._dbmetadata[obj_type][catlower][1][schlower][1][obj.lower()] = (obj, {})
        # If we passed nothing then take out that element entirely out
        # of the dict
        else:
            with self._lock:
                catalog_entry = self._dbmetadata[obj_type].get(catlower)
                if catalog_entry is not None:
                    catalog_entry[1].pop(schlower, None)

        return

    def get_objects(self, catalog: str, schema: str, obj_type: str = "table") -> list:
        """ Retrieve objects as the keys for _dbmetadata[obj_type][catalog][schema]
        If catalog is not part of the _dbmetadata[obj_type] keys, or schema
        not one of the keys in _dbmetadata[obj_type][catalog] will return None
        """

        catlower = catalog.lower()
        schlower = schema.lower()
        schemas = self.get_schemas(catalog = catalog, obj_type = obj_type, cased = False)
        if schemas is None or schlower not in schemas:
            return None

        with self._lock:
            res = [casedkey for casedkey, mappedvalue in self._dbmetadata[obj_type][catlower][1][schlower][1].values()]

        return list(res)

    def reset_metadata(self) -> None:
        """ Drop everything cached and start from empty trees. """
        with self._lock:
            self._dbmetadata = {'table': {}, 'view': {}, 'function': {},
                    'datatype': {}}

    @property
    def data(self) -> dict:
        # Direct (unlocked) view of the underlying dict; callers must not
        # mutate it.
        return self._dbmetadata
116 |
--------------------------------------------------------------------------------
/odbcli/disconnect_dialog.py:
--------------------------------------------------------------------------------
1 | from prompt_toolkit.widgets import Button, Dialog, Label
2 | from prompt_toolkit.layout.containers import ConditionalContainer
3 | from prompt_toolkit.layout.dimension import Dimension as D
4 | from prompt_toolkit.filters import is_done
5 | from .conn import connWrappers
6 | from .filters import ShowDisconnectDialog
7 |
def disconnect_dialog(my_app: "sqlApp"):
    """ Build the Disconnect / Reset-Completions dialog for the object
    currently selected in the sidebar.

    Returns a ConditionalContainer that is shown while the app's
    show_disconnect_dialog flag is set.  All three buttons previously
    duplicated the same hide-dialog/refocus tail; it now lives in the
    _dismiss helper.
    """
    def _dismiss() -> None:
        # Common tail for every button: hide this dialog, re-show the
        # sidebar and hand focus back to it.
        my_app.show_disconnect_dialog = False
        my_app.show_sidebar = True
        my_app.application.layout.focus("sidebarbuffer")

    def yes_handler() -> None:
        # This is not preferred since completer may currently have
        # a different connection attached.
        # my_app.completer.reset_completions()
        obj = my_app.selected_object
        obj.conn.dbmetadata.reset_metadata()
        obj.collapse()
        obj.conn.close()
        if my_app.active_conn is obj.conn:
            my_app.active_conn = None
        _dismiss()

    def rc_handler() -> None:
        # Keep the connection; only flush the completion metadata cache.
        obj = my_app.selected_object
        obj.conn.dbmetadata.reset_metadata()
        _dismiss()

    def no_handler() -> None:
        _dismiss()

    dialog = Dialog(
        title = lambda: my_app.selected_object.name,
        body = Label(text = "Disconnect or Reset Completions?",
            dont_extend_height = True),
        buttons = [
            Button(text = "Disconnect", handler = yes_handler),
            Button(text = "Reset Completions", handler = rc_handler, width = 20),
            Button(text = "Cancel", handler = no_handler),
        ],
        width = D(min = 10, preferred = 50),
        with_background = False,
    )

    return ConditionalContainer(
        content = dialog,
        filter = ShowDisconnectDialog(my_app) & ~is_done
    )
52 |
--------------------------------------------------------------------------------
/odbcli/filters.py:
--------------------------------------------------------------------------------
1 | from prompt_toolkit.filters import Filter
2 | from prompt_toolkit.application import get_app
3 | from re import sub, findall
4 |
class SqlAppFilter(Filter):
    """ Base class for prompt_toolkit filters that consult sqlApp state.

    Subclasses implement __call__ to report whether their associated UI
    element should currently be shown/active.
    """
    def __init__(self, sql_app: "sqlApp") -> None:
        super().__init__()
        # Application object whose boolean flags subclasses read.
        self.my_app = sql_app

    def __call__(self) -> bool:
        raise NotImplementedError
12 |
class ShowSidebar(SqlAppFilter):
    """Sidebar is shown unless the exit-confirmation prompt is up."""

    def __call__(self) -> bool:
        state = self.my_app
        return state.show_sidebar and not state.show_exit_confirmation
16 |
class ShowLoginPrompt(SqlAppFilter):
    """Active while the credentials dialog should be displayed."""

    def __call__(self) -> bool:
        state = self.my_app
        return state.show_login_prompt
20 |
class ShowPreview(SqlAppFilter):
    """Active while the object preview pane should be displayed."""

    def __call__(self) -> bool:
        state = self.my_app
        return state.show_preview
24 |
class ShowDisconnectDialog(SqlAppFilter):
    """Active while the disconnect/reset-completions dialog is up."""

    def __call__(self) -> bool:
        state = self.my_app
        return state.show_disconnect_dialog
28 |
class MultilineFilter(SqlAppFilter):
    """Keep the main buffer in multiline mode until the text forms
    something that should be submitted: a complete GO-terminated batch, a
    special command, an exit command, or an empty line.
    """

    def _is_open_quote(self, sql: str) -> bool:
        """To implement: detect an unterminated quote in *sql*."""
        return False

    def _is_query_executable(self, sql: str) -> bool:
        """Return True when *sql* ends with a lone GO batch terminator.

        A complete command is an sql statement that ends with a 'GO',
        unless there's an open quote or an open block comment surrounding
        it, as is common when writing a CREATE FUNCTION command.
        """
        if sql is not None and sql != "":
            # Remove all closed quotes so quoted text cannot hide or fake
            # comment markers.
            sql_no_quotes = sub(r'".*?"|\'.*?\'', '', sql)
            # BUG FIX: strip closed block comments first; only a '/*' that
            # survives marks a genuinely open comment.  Previously any
            # '/*' -- even one already closed by '*/' -- blocked
            # execution, so GO after a closed block comment was ignored
            # (see https://github.com/andialbrecht/sqlparse/issues/484).
            sql_no_comments = sub(r'(?s)/\*.*?\*/', '', sql_no_quotes)
            is_open_comment = '/*' in sql_no_comments
            # 'go' must be the only token on the last line.
            is_valid_go_on_lastline = sql.split('\n')[-1].lower().strip() == 'go'
            return not self._is_open_quote(sql) and not is_open_comment and is_valid_go_on_lastline
        return False

    def _multiline_exception(self, text: str) -> bool:
        """True when the buffer contents should be accepted as-is."""
        text = text.strip()
        return (
            text.startswith('\\') or            # Special Command
            text.endswith(r'\e') or             # \e should launch the editor
            self._is_query_executable(text) or  # A complete SQL command
            text.endswith(';') or               # GO doesn't work everywhere
            text in ('exit', 'quit', ':q', '')  # Exit commands / plain enter
        )

    def __call__(self) -> bool:
        doc = get_app().layout.get_buffer_by_name("defaultbuffer").document
        if not self.my_app.multiline:
            return False
        return not self._multiline_exception(doc.text)
75 |
--------------------------------------------------------------------------------
/odbcli/layout.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from prompt_toolkit.layout.processors import AppendAutoSuggestion
3 | from prompt_toolkit.key_binding.vi_state import InputMode
4 | from prompt_toolkit.layout.containers import VSplit, HSplit, Window, ConditionalContainer, FloatContainer, Container, Float, ScrollOffsets
5 | from prompt_toolkit.buffer import Buffer
6 | from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl
7 | from prompt_toolkit.document import Document
8 | from prompt_toolkit.filters import Condition, has_focus, is_done, renderer_height_is_known
9 | from prompt_toolkit.layout.layout import Layout
10 | from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter
11 | from prompt_toolkit.history import History, InMemoryHistory, FileHistory, ThreadedHistory
12 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest
13 | from prompt_toolkit.layout.dimension import Dimension
14 | from prompt_toolkit.widgets import SearchToolbar
15 | from prompt_toolkit.formatted_text import StyleAndTextTuples, to_formatted_text
16 | from prompt_toolkit.layout.menus import CompletionsMenu
17 | from prompt_toolkit.application import get_app
18 | from prompt_toolkit.mouse_events import MouseEvent
19 | from prompt_toolkit.lexers import PygmentsLexer
20 | from prompt_toolkit.selection import SelectionType
21 | from pygments.lexers.sql import SqlLexer
22 | from os.path import expanduser
23 | from .sidebar import sql_sidebar, sql_sidebar_help, show_sidebar_button_info, sql_sidebar_navigation
24 | from .loginprompt import login_prompt
25 | from .disconnect_dialog import disconnect_dialog
26 | from .preview import PreviewElement
27 | from .filters import ShowLoginPrompt, ShowSidebar, MultilineFilter
28 | from .utils import if_mousedown
29 | from .conn import connStatus
30 | from .config import config_location, ensure_dir_exists
31 |
def get_inputmode_fragments(my_app: "sqlApp") -> StyleAndTextTuples:
    """
    Return current input mode as a list of (token, text) tuples for use in a
    toolbar.

    Fragments cover the [F-4] mode toggle, the Vi/Emacs mode label (with
    macro-recording and visual-selection variants) and the exit hint.
    Mode fragments carry a mouse handler so clicking them flips the mode.
    """
    app = get_app()

    @if_mousedown
    def toggle_vi_mode(mouse_event: MouseEvent) -> None:
        # Mouse handler: clicking a mode fragment toggles Vi <-> Emacs.
        my_app.vi_mode = not my_app.vi_mode

    token = "class:status-toolbar"
    input_mode_t = "class:status-toolbar.input-mode"

    mode = app.vi_state.input_mode
    result: StyleAndTextTuples = []
    append = result.append

    # Title display is currently disabled; kept for future use.
    # if my_app.title:
    if False:
        result.extend(to_formatted_text(my_app.title))

    append((input_mode_t, "[F-4] ", toggle_vi_mode))

    # InputMode
    if my_app.vi_mode:
        # Show an indicator while a Vi macro is being recorded.
        recording_register = app.vi_state.recording_register
        if recording_register:
            append((token, " "))
            append((token + " class:record", "RECORD({})".format(recording_register)))
            append((token, " - "))

        # An active selection takes precedence over the raw input mode.
        if app.current_buffer.selection_state is not None:
            if app.current_buffer.selection_state.type == SelectionType.LINES:
                append((input_mode_t, "Vi (VISUAL LINE)", toggle_vi_mode))
            elif app.current_buffer.selection_state.type == SelectionType.CHARACTERS:
                append((input_mode_t, "Vi (VISUAL)", toggle_vi_mode))
                append((token, " "))
            elif app.current_buffer.selection_state.type == SelectionType.BLOCK:
                append((input_mode_t, "Vi (VISUAL BLOCK)", toggle_vi_mode))
                append((token, " "))
        elif mode in (InputMode.INSERT, "vi-insert-multiple"):
            append((input_mode_t, "Vi (INSERT)", toggle_vi_mode))
            append((token, " "))
        elif mode == InputMode.NAVIGATION:
            append((input_mode_t, "Vi (NAV)", toggle_vi_mode))
            append((token, " "))
        elif mode == InputMode.REPLACE:
            append((input_mode_t, "Vi (REPLACE)", toggle_vi_mode))
            append((token, " "))
    else:
        # Emacs mode: show macro-recording state, then the mode label.
        if app.emacs_state.is_recording:
            append((token, " "))
            append((token + " class:record", "RECORD"))
            append((token, " - "))

        append((input_mode_t, "Emacs", toggle_vi_mode))
        append((token, " "))

    append((input_mode_t, "[C-q] Exit Client", ))

    return result
94 |
def get_connection_fragments(my_app: "sqlApp") -> StyleAndTextTuples:
    """
    Return the state of the active connection as a list of (token, text)
    tuples for use in the status toolbar.

    Falls back to DISCONNECTED when no connection is active.  (An unused
    ``app = get_app()`` call was removed from the original.)
    """
    status = my_app.active_conn.status if my_app.active_conn else connStatus.DISCONNECTED
    if status == connStatus.FETCHING:
        token = "class:status-toolbar.conn-fetching"
        status_text = "Fetching"
    elif status == connStatus.EXECUTING:
        token = "class:status-toolbar.conn-executing"
        status_text = "Executing"
    elif status == connStatus.ERROR:
        token = "class:status-toolbar.conn-executing"
        status_text = "Unexpected Error"
    elif status == connStatus.DISCONNECTED:
        token = "class:status-toolbar.conn-fetching"
        status_text = "Disconnected"
    else:
        token = "class:status-toolbar.conn-idle"
        status_text = "Idle"

    return [(token, " " + status_text)]
123 |
def exit_confirmation(
    my_app: "sqlApp", style = "class:exit-confirmation"
) -> Container:
    """
    Create the `Layout` for the exit confirmation message.

    Shown as a small float while ``show_exit_confirmation`` is set and the
    application is not finishing.
    """

    def fragments() -> StyleAndTextTuples:
        # Render "Do you really want to exit? ([y]/n)" with the cursor
        # parked right after the prompt.
        return [
            (style, "\n %s ([y]/n)" % my_app.exit_message),
            ("[SetCursorPosition]", ""),
            (style, " \n"),
        ]

    return ConditionalContainer(
        content=Window(
            FormattedTextControl(fragments), style=style
        ),
        filter=~is_done & Condition(lambda: my_app.show_exit_confirmation),
    )
147 |
148 |
149 |
def status_bar(my_app: "sqlApp") -> Container:
    """
    Create the `Layout` for the bottom status bar: input-mode fragments
    followed by connection-state fragments.  Hidden while the exit
    confirmation is up or the renderer height is unknown.
    """
    TB = "class:status-toolbar"

    def fragments() -> StyleAndTextTuples:
        parts: StyleAndTextTuples = [(TB, " ")]
        parts.extend(get_inputmode_fragments(my_app))
        parts.append((TB, " "))
        parts.extend(get_connection_fragments(my_app))
        return parts

    no_exit_prompt = Condition(lambda: not my_app.show_exit_confirmation)
    return ConditionalContainer(
        content=Window(content=FormattedTextControl(fragments), style=TB),
        filter=~is_done & renderer_height_is_known & no_exit_prompt,
    )
177 |
178 | def sql_line_prefix(
179 | line_number: int,
180 | wrap_count: int,
181 | my_app: "sqlApp"
182 | ) -> StyleAndTextTuples:
183 | if my_app.active_conn is not None:
184 | sqlConn = my_app.active_conn
185 | prompt = sqlConn.username + "@" + sqlConn.dsn + ":" + sqlConn.current_catalog() + " > "
186 | else:
187 | prompt = "> "
188 | if line_number == 0 and wrap_count == 0:
189 | return to_formatted_text([("class:prompt", prompt)])
190 | prompt_width = len(prompt)
191 | return [("class:prompt.dots", "." * (prompt_width - 1) + " ")]
192 |
class sqlAppLayout:
    """Top-level prompt_toolkit layout for the client.

    Owns the main SQL input buffer and assembles the widget tree: the
    editor window with its search toolbar, the floating dialogs (sidebar
    help, login, preview, disconnect, exit confirmation, completions
    menu), the sidebar, and the bottom status bar.
    """

    def __init__(
        self,
        my_app: "sqlApp"
    ) -> None:

        self.my_app = my_app
        self.search_field = SearchToolbar()
        # Persist command history under the user's config directory.
        history_file = config_location() + 'history'
        ensure_dir_exists(history_file)
        hist = ThreadedHistory(FileHistory(expanduser(history_file)))
        self.input_buffer = Buffer(
            name = "defaultbuffer",
            tempfile_suffix = ".py",
            # MultilineFilter decides when <Enter> inserts a newline
            # instead of accepting the input.
            multiline = MultilineFilter(self.my_app),
            history = hist,
            completer = ThreadedCompleter(self.my_app.completer),
            auto_suggest = ThreadedAutoSuggest(AutoSuggestFromHistory()),
            # Only offer completions while a connection is active.
            complete_while_typing = Condition(
                lambda: self.my_app.active_conn is not None
            )
        )
        main_win_control = BufferControl(
            buffer = self.input_buffer,
            lexer = PygmentsLexer(SqlLexer),
            search_buffer_control = self.search_field.control,
            include_default_input_processors = False,
            input_processors = [AppendAutoSuggestion()],
            preview_search = True
        )

        self.main_win = Window(
            main_win_control,
            # No height constraint once the app is done; while the preview
            # is shown, prefer a tall window.
            height = (
                lambda: (
                    None
                    if get_app().is_done
                    else (Dimension(min = self.my_app.min_num_menu_lines) if not self.my_app.show_preview else Dimension(min = self.my_app.min_num_menu_lines, preferred = 180))
                )
            ),
            get_line_prefix = partial(sql_line_prefix, my_app = self.my_app),
            scroll_offsets=ScrollOffsets(bottom = 1, left = 4, right = 4)
        )

        preview_element = PreviewElement(self.my_app)
        self.lprompt = login_prompt(self.my_app)
        self.preview = preview_element.create_container()
        self.disconnect_dialog = disconnect_dialog(self.my_app)
        # Widget tree: editor plus floats on the left, sidebar on the
        # right, status bar across the bottom.
        container = HSplit([
            VSplit([
                FloatContainer(
                    content=HSplit(
                        [
                            self.main_win,
                            self.search_field,
                        ]
                    ),
                    floats=[
                        Float(
                            bottom = 1,
                            left = 1,
                            right = 0,
                            content = sql_sidebar_help(self.my_app),
                        ),
                        Float(
                            content = self.lprompt
                        ),
                        Float(
                            content = self.preview,
                        ),
                        preview_element.create_completion_float(),
                        Float(
                            content = self.disconnect_dialog,
                        ),
                        Float(
                            left = 2,
                            bottom = 1,
                            content = exit_confirmation(self.my_app)
                        ),
                        Float(
                            xcursor = True,
                            ycursor = True,
                            transparent = True,
                            content = CompletionsMenu(
                                scroll_offset = 1,
                                max_height = 16,
                                extra_filter = has_focus(self.input_buffer)
                            )
                        )
                    ]
                ),
                ConditionalContainer(
                    content = sql_sidebar(self.my_app),
                    filter=ShowSidebar(self.my_app) & ~is_done,
                )
            ]),
            VSplit(
                [status_bar(self.my_app), show_sidebar_button_info(self.my_app)]
            )
        ])

        def accept(buff):
            # Accepting the buffer exits the current run of the
            # application, handing the query text back to the caller; the
            # buffer reset is deferred until the next run starts.
            app = get_app()
            app.exit(result = ["non-preview", buff.text])
            app.pre_run_callables.append(buff.reset)
            return True

        self.input_buffer.accept_handler = accept
        self.layout = Layout(container, focused_element = self.main_win)
302 |
--------------------------------------------------------------------------------
/odbcli/loginprompt.py:
--------------------------------------------------------------------------------
1 | from prompt_toolkit.buffer import Buffer
2 | from prompt_toolkit.layout.dimension import Dimension as D
3 | from prompt_toolkit.widgets import Button, Dialog, Label
4 | from prompt_toolkit.widgets import TextArea
5 | from prompt_toolkit.layout.containers import HSplit, ConditionalContainer, WindowAlign, Window
6 | from prompt_toolkit.filters import is_done
7 | from cyanodbc import ConnectError, DatabaseError, SQLGetInfo
8 | from .conn import connWrappers, sqlConnection
9 | from .filters import ShowLoginPrompt
10 |
def login_prompt(my_app: "sqlApp"):
    """Build the conditional username/password dialog.

    Returns a ConditionalContainer visible while
    ``my_app.show_login_prompt`` is set.  On OK, the selected sidebar
    object's connection is established and re-wrapped in a DBMS-specific
    connection class; on Cancel the dialog is simply dismissed.
    """

    def ok_handler() -> None:
        # Move focus off the buttons before attempting the connection.
        my_app.application.layout.focus(uidTextfield)
        obj = my_app.selected_object
        try:
            obj.conn.connect(username = uidTextfield.text, password = pwdTextfield.text)
            # Query the type of back-end and instantiate an appropriate class
            dbms = obj.conn.get_info(SQLGetInfo.SQL_DBMS_NAME)
            # Now clone object
            cls = connWrappers[dbms] if dbms in connWrappers.keys() else sqlConnection
            newConn = cls(
                dsn = obj.conn.dsn,
                conn = obj.conn.conn,
                username = obj.conn.username,
                password = obj.conn.password)
            obj.conn.close()
            newConn.connect()
            obj.conn = newConn
            my_app.active_conn = obj.conn
            # OG some thread locking may be needed here
            obj.expand()
        except ConnectError as e:
            # Keep the dialog open and show the failure to the user.
            msgLabel.text = "Connect failed"
        else:
            # Success: clear any prior error, hide the prompt and hand
            # focus back to the sidebar.
            msgLabel.text = ""
            my_app.show_login_prompt = False
            my_app.show_sidebar = True
            my_app.application.layout.focus("sidebarbuffer")

        # Always wipe the entered credentials, whether or not the
        # connection attempt succeeded.
        uidTextfield.text = ""
        pwdTextfield.text = ""

    def cancel_handler() -> None:
        msgLabel.text = ""
        my_app.application.layout.focus(uidTextfield)
        my_app.show_login_prompt = False
        my_app.show_sidebar = True
        my_app.application.layout.focus("sidebarbuffer")

    def accept(buf: Buffer) -> bool:
        # <Enter> in either text field advances focus to the OK button.
        my_app.application.layout.focus(ok_button)
        return True

    ok_button = Button(text="OK", handler=ok_handler)
    cancel_button = Button(text="Cancel", handler=cancel_handler)

    pwdTextfield = TextArea(
        multiline=False, password=True, accept_handler=accept
    )
    uidTextfield = TextArea(
        multiline=False, password=False, accept_handler=accept
    )
    msgLabel = Label(text = "", dont_extend_height = True, style = "class:frame.label")
    msgLabel.window.align = WindowAlign.CENTER
    dialog = Dialog(
        title="Server Credentials",
        body=HSplit(
            [
                Label(text="Username ", dont_extend_height=True),
                uidTextfield,
                Label(text="Password", dont_extend_height=True),
                pwdTextfield,
                msgLabel
            ],
            padding=D(preferred=1, max=1),
        ),
        width = D(min = 10, preferred = 50),
        buttons=[ok_button, cancel_button],
        with_background = False
    )

    return ConditionalContainer(
        content = dialog,
        filter = ShowLoginPrompt(my_app) & ~is_done
    )
87 |
--------------------------------------------------------------------------------
/odbcli/odbclirc:
--------------------------------------------------------------------------------
1 | # vi: ft=dosini
2 | [main]
3 |
4 | # Multi-line mode allows breaking up the sql statements into multiple lines. If
5 | # this is set to True, then the end of the statements must have a semi-colon.
6 | # If this is set to False then sql statements can't be split into multiple
7 | # lines. End of line (return) is considered as the end of the statement.
8 | multi_line = False
9 |
10 | # log_file location.
11 | # In Unix/Linux: ~/.config/odbcli/log
12 | # In Windows: %USERPROFILE%\AppData\Local\dbcli\odbcli\log
13 | # %USERPROFILE% is typically C:\Users\{username}
14 | log_file = default
15 |
16 | # history_file location.
17 | # In Unix/Linux: ~/.config/odbcli/history
18 | # In Windows: %USERPROFILE%\AppData\Local\dbcli\odbcli\history
19 | # %USERPROFILE% is typically C:\Users\{username}
20 | history_file = default
21 |
22 | # Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
23 | # and "DEBUG". "NONE" disables logging.
24 | log_level = INFO
25 |
26 | # Mouse support. Limited mouse support for positioning the cursor, for
27 | # scrolling, and for clicking in the autocompletions menu. Recommended
28 | # False, as it may cause surprising behavior when trying to use the mouse.
29 | mouse_support = False
30 |
31 | # Default pager.
32 | # By default 'PAGER' environment variable is used
33 | # pager = less -SRXF
34 |
# Number of lines to reserve for pager non-content output
36 | # When table format is psql, for less this is 1, pypager 2. Internally we use
37 | # this number to fetch just enough rows of data so that we can fit the table
38 | # columns at the top of each page - therefore this number also depends on the
39 | # format used
40 | pager_reserve_lines = 1
41 |
42 | # Timing of sql statement execution.
43 | timing = True
44 |
45 | # Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe,
46 | # ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs,
47 | # textile, moinmoin, jira, vertical, tsv, csv.
48 | # Recommended: psql, fancy_grid and grid.
49 | table_format = psql
50 |
51 | # Syntax Style. Possible values: manni, igor, xcode, vim, autumn, vs, rrt,
52 | # native, perldoc, borland, tango, emacs, friendly, monokai, paraiso-dark,
53 | # colorful, murphy, bw, pastie, paraiso-light, trac, default, fruity
54 | syntax_style = default
55 |
56 | # Keybindings:
57 | # When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
58 | # When Vi mode is disabled emacs keybindings such as Ctrl-A for home and Ctrl-E
59 | # for end are available in the REPL.
60 | vi = False
61 |
62 | # Number of lines to reserve for the suggestion menu
63 | min_num_menu_lines = 5
64 |
65 |
66 | # When viewing the results of a query in the main execution prompt of the client
67 | # this option determines how many (expressed as a multiplier of the number of
68 | # rows on the screen) results to pre-fetch before displaying the results screen.
69 | # The fetch operation is asynchronous / in a background process; this option
70 | # essentially determines the trade-off between smooth scrolling through the
71 | # results (higher multiplier) and how quickly the initial results are displayed
72 | # on the screen (lower multiplier). Generally a number >= 3 is recommended.
73 | fetch_chunk_multiplier = 5
74 |
75 | # When previewing a table we SELECT *. If preview_limit_rows is > 0
76 | # we attempt to limit the maximum number of rows fetched to this number.
77 | preview_limit_rows = 500
78 |
79 | # When previewing a table, the asynchronous operation will fetch all the
80 | # results from the SELECT query in chunks of fixed sizes asynchronously.
81 | # This number should generally be larger than the number of rows on the screen
82 | # but comfortably less than the total number of rows limited by
83 | # preview_limit_rows. More likely than not, you should not need to change this
84 | # number.
85 | preview_fetch_chunk_size = 100
86 |
87 | # Custom colors for the completion menu, toolbar, etc.
88 | [colors]
89 | completion-menu.completion.current = 'bg:#ffffff #000000'
90 | completion-menu.completion = 'bg:#008888 #ffffff'
91 | completion-menu.meta.completion.current = 'bg:#44aaaa #000000'
92 | completion-menu.meta.completion = 'bg:#448888 #ffffff'
93 | completion-menu.multi-column-meta = 'bg:#aaffff #000000'
94 | scrollbar.arrow = 'bg:#003333'
95 | scrollbar = 'bg:#00aaaa'
96 | selected = '#ffffff bg:#6666aa'
97 | search = '#ffffff bg:#4444aa'
98 | search.current = '#ffffff bg:#44aa44'
99 | bottom-toolbar = 'bg:#222222 #aaaaaa'
100 | bottom-toolbar.off = 'bg:#222222 #888888'
101 | bottom-toolbar.on = 'bg:#222222 #ffffff'
102 | search-toolbar = 'noinherit bold'
103 | search-toolbar.text = 'nobold'
104 | system-toolbar = 'noinherit bold'
105 | arg-toolbar = 'noinherit bold'
106 | arg-toolbar.text = 'nobold'
107 | bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
108 | bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
109 | literal.string = '#ba2121'
110 | literal.number = '#666666'
111 | keyword = 'bold #008000'
112 |
113 | preview-output-field = 'bg:#000044 #ffffff'
114 | preview-input-field = 'bg:#000066 #ffffff'
115 | preview-divider-line = '#000000'
116 | status-toolbar = "bg:#222222 #aaaaaa"
117 | status-toolbar.title = "underline"
118 | status-toolbar.inputmode = "bg:#222222 #ffffaa"
119 | status-toolbar.key = "bg:#000000 #888888"
120 | status-toolbar.pastemodeon = "bg:#aa4444 #ffffff"
121 | status-toolbar.cli-version = "bg:#222222 #ffffff bold"
122 | status-toolbar paste-mode-on = "bg:#aa4444 #ffffff"
123 | record = "bg:#884444 white"
124 | status-toolbar.input-mode = "#ffff44"
125 | status-toolbar.conn-executing = "bg:red #ffff44"
126 | status-toolbar.conn-fetching = "bg:yellow black"
127 | status-toolbar.conn-idle = "bg:#668866 #ffffff"
128 | # The options sidebar.
129 | sidebar = "bg:#bbbbbb #000000"
130 | sidebar.title = "bg:#668866 fg:#ffffff"
131 | sidebar.label = "bg:#bbbbbb fg:#222222"
132 | sidebar.status = "bg:#dddddd #000011"
133 | sidebar.label selected = "bg:#222222 #eeeeee bold"
134 | sidebar.status selected = "bg:#444444 #ffffff bold"
135 | sidebar.label active = "bg:#668866 #ffffff"
136 | sidebar.status active = "bg:#88AA88 #ffffff"
137 | sidebar.separator = "underline"
138 | sidebar.navigation.key = "bg:#bbddbb #000000 bold"
139 | sidebar.navigation.description = "bg:#dddddd #000011"
140 | sidebar.navigation = "bg:#dddddd"
141 | sidebar.helptext = "bg:#fdf6e3 #000011"
142 |
143 | # Exit confirmation
144 | exit-confirmation = "bg:#884444 #ffffff"
145 |
146 | # style classes for colored table output
147 | output.header = "#00ff5f bold"
148 | output.odd-row = ""
149 | output.even-row = ""
150 | output.null = "#808080"
151 |
--------------------------------------------------------------------------------
/odbcli/odbcstyle.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pygments.styles
4 | from pygments.token import string_to_tokentype, Token
5 | from pygments.style import Style as PygmentsStyle
6 | from pygments.util import ClassNotFound
7 | from prompt_toolkit.styles.pygments import style_from_pygments_cls
8 | from prompt_toolkit.styles import merge_styles, Style
9 |
10 | logger = logging.getLogger(__name__)
11 |
# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
# Used by style_factory/style_factory_output to translate legacy
# "Token.*" entries in user config into prompt_toolkit style class names.
TOKEN_TO_PROMPT_STYLE = {
    Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
    Token.Menu.Completions.Completion: "completion-menu.completion",
    Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
    Token.Menu.Completions.Meta: "completion-menu.meta.completion",
    Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
    Token.Menu.Completions.ProgressButton: "scrollbar.arrow",  # best guess
    Token.Menu.Completions.ProgressBar: "scrollbar",  # best guess
    Token.SelectedText: "selected",
    Token.SearchMatch: "search",
    Token.SearchMatch.Current: "search.current",
    Token.Toolbar: "bottom-toolbar",
    Token.Toolbar.Off: "bottom-toolbar.off",
    Token.Toolbar.On: "bottom-toolbar.on",
    Token.Toolbar.Search: "search-toolbar",
    Token.Toolbar.Search.Text: "search-toolbar.text",
    Token.Toolbar.System: "system-toolbar",
    Token.Toolbar.Arg: "arg-toolbar",
    Token.Toolbar.Arg.Text: "arg-toolbar.text",
    Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
    Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
    Token.Output.Header: "output.header",
    Token.Output.OddRow: "output.odd-row",
    Token.Output.EvenRow: "output.even-row",
    Token.Output.Null: "output.null",
    Token.Literal.String: "literal.string",
    Token.Literal.Number: "literal.number",
    Token.Keyword: "keyword",
}

# reverse dict for cli_helpers, because they still expect Pygments tokens.
PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
45 |
46 |
def parse_pygments_style(token_name, style_object, style_dict):
    """Resolve one entry of a user style dict against a base style.

    :param token_name: str name of Pygments token. Example: "Token.String"
    :param style_object: pygments.style.Style instance to use as base
    :param style_dict: dict of token names and their styles, customized to this cli

    The user value may itself name another token; if the base exposes a
    ``styles`` mapping, that token's base style is returned.  Otherwise
    (AttributeError) the raw user value is returned unchanged.
    """
    token_type = string_to_tokentype(token_name)
    try:
        aliased_token = string_to_tokentype(style_dict[token_name])
        return token_type, style_object.styles[aliased_token]
    except AttributeError:
        return token_type, style_dict[token_name]
61 |
62 |
def style_factory(name, cli_style):
    """Merge a named Pygments style with user overrides into a ptk Style.

    Falls back to the "native" Pygments style when *name* is unknown.
    """
    try:
        base = pygments.styles.get_style_by_name(name)
    except ClassNotFound:
        base = pygments.styles.get_style_by_name("native")

    prompt_styles = []
    for token in cli_style:
        if not token.startswith("Token."):
            # Already a prompt style name (2.0). See default style names here:
            # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
            prompt_styles.append((token, cli_style[token]))
            continue
        # prompt-toolkit used pygments tokens for styling before, switched
        # to style names in 2.0. Convert old token types to new style
        # names, for backwards compatibility.
        token_type, style_value = parse_pygments_style(token, base, cli_style)
        if token_type in TOKEN_TO_PROMPT_STYLE:
            prompt_styles.append((TOKEN_TO_PROMPT_STYLE[token_type], style_value))
        else:
            # we don't want to support tokens anymore
            logger.error("Unhandled style / class name: %s", token)

    override_style = Style([("bottom-toolbar", "noreverse")])
    return merge_styles(
        [style_from_pygments_cls(base), override_style, Style(prompt_styles)]
    )
91 |
92 |
def style_factory_output(name, cli_style):
    """Build a Pygments ``Style`` subclass for tabular output coloring.

    :param name: name of the base Pygments style (falls back to "native")
    :param cli_style: dict of user overrides; keys are either legacy
        Pygments token names ("Token.*") or prompt_toolkit class names
    """
    try:
        base_styles = pygments.styles.get_style_by_name(name).styles
    except ClassNotFound:
        base_styles = pygments.styles.get_style_by_name("native").styles

    # BUG FIX: work on a copy.  ``styles`` is a class attribute of the
    # Pygments style class, so updating it in place leaked our overrides
    # into every other consumer of that style (including later calls for
    # the same style name).
    style = dict(base_styles)

    for token in cli_style:
        if token.startswith("Token."):
            # Passing the dict as style_object means the token-alias
            # lookup always falls through to the raw user value
            # (AttributeError branch) — same behavior as before the copy.
            token_type, style_value = parse_pygments_style(token, style, cli_style)
            style.update({token_type: style_value})
        elif token in PROMPT_STYLE_TO_TOKEN:
            token_type = PROMPT_STYLE_TO_TOKEN[token]
            style.update({token_type: cli_style[token]})
        else:
            # TODO: cli helpers will have to switch to ptk.Style
            logger.error("Unhandled style / class name: %s", token)

    class OutputStyle(PygmentsStyle):
        default_style = ""
        styles = style

    return OutputStyle
115 |
--------------------------------------------------------------------------------
/odbcli/preview.py:
--------------------------------------------------------------------------------
1 | from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightIncrementalSearchProcessor, HighlightSelectionProcessor, AppendAutoSuggestion
2 | from prompt_toolkit.buffer import Buffer
3 | from prompt_toolkit.layout.controls import BufferControl
4 | from prompt_toolkit.document import Document
5 | from prompt_toolkit.layout.menus import CompletionsMenu
6 | from prompt_toolkit.filters import has_focus, is_done
7 | from prompt_toolkit.layout.dimension import Dimension as D
8 | from prompt_toolkit.widgets import Button, TextArea, SearchToolbar, Box, Shadow, Frame
9 | from prompt_toolkit.layout.containers import Window, VSplit, HSplit, ConditionalContainer, FloatContainer, Float
10 | from prompt_toolkit.filters import Condition, is_done
11 | from prompt_toolkit.completion import Completer, Completion, ThreadedCompleter, CompleteEvent
12 | from prompt_toolkit.history import History, FileHistory, ThreadedHistory
13 | from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest, AutoSuggest, Suggestion
14 | from cyanodbc import ConnectError, DatabaseError
15 | from cli_helpers.tabular_output import TabularOutputFormatter
16 | from functools import partial
17 | from typing import Callable, Iterable, List, Optional
18 | from logging import getLogger
19 | from os.path import expanduser
20 | from .completion.mssqlcompleter import MssqlCompleter
21 | from .filters import ShowPreview
22 | from .conn import connWrappers, connStatus, executionStatus
23 | from .config import config_location, ensure_dir_exists
24 |
def object_to_identifier(obj: "myDBObject") -> str:
    """Build the quoted, dot-separated identifier for a sidebar object.

    Walks up to two levels of the object hierarchy to pick up schema
    and/or catalog names, quotes every part with the connection's quote
    character, and joins the non-empty parts with ".".
    """
    # TODO: Verify connected
    quote = obj.conn.quotechar
    catalog = None
    schema = None
    parent = obj.parent
    if parent is not None:
        parent_kind = type(parent).__name__
        if parent_kind == "myDBSchema":
            schema = parent.name
        elif parent_kind == "myDBCatalog":
            catalog = parent.name
        grandparent = parent.parent
        if grandparent is not None and type(grandparent).__name__ == "myDBCatalog":
            catalog = grandparent.name

    parts = []
    if catalog:
        parts.append(quote + catalog + quote)
    if schema:
        parts.append(quote + schema + quote)
    parts.append(quote + obj.name + quote)
    return ".".join(parts)
47 |
48 |
class PreviewCompleter(Completer):
    """ Wraps prompt_toolkit.Completer. The buffer this completer is
    attached to holds only a fragment of the query (e.g. 'WHERE ...');
    to complete effectively we rebuild the full preview query, place the
    cursor at the corresponding offset inside it, and delegate to the
    wrapped completer.
    Rather than wrapping, probably should extend the class - however
    at this time as the completer class is fairly hacked up and not
    in a steady state, let's stay with the wrap.
    """
    def __init__(self, my_app: "sqlApp", completer: Completer) -> None:
        self.completer = completer
        self.my_app = my_app

    def get_completions(
        self, document: Document, complete_event: CompleteEvent
    ) -> Iterable[Completion]:
        obj = self.my_app.selected_object
        query = obj.conn.preview_query(
            name = object_to_identifier(obj),
            obj_type = obj.otype,
            filter_query = document.text,
            limit = self.my_app.preview_limit_rows)
        if query is None:
            return []

        # Map the fragment's cursor position into the full query text.
        offset = query.find(document.text) + document.cursor_position
        full_document = Document(text = query, cursor_position = offset)
        return self.completer.get_completions(full_document, complete_event)
80 |
class PreviewHistory(FileHistory):
    """ FileHistory variant that namespaces entries per database object.

    Each stored line is prefixed with "[identifier]: " so suggestions can
    later be matched against the object currently being previewed.
    """
    def __init__(self, filename: str, my_app: "sqlApp") -> None:
        self.my_app = my_app
        super().__init__(filename)

    def store_string(self, string: str) -> None:
        """ Prefix the filter text with the selected object's quoted
        identifier before persisting it to the history file. """
        prefix = object_to_identifier(self.my_app.selected_object)
        super().store_string(prefix + ": " + string)
93 |
class PreviewSuggestFromHistory(AutoSuggest):
    """
    Give suggestions based on the lines in the history.

    Only history entries stored for the currently selected database
    object are considered, i.e. lines carrying the "[identifier]: "
    prefix written by PreviewHistory.
    """
    def __init__(self, my_app: "sqlApp") -> None:
        self.my_app = my_app
        # BUG FIX: the original read ``super().__init__`` without calling
        # it -- a no-op attribute access.  Add the parentheses so the
        # base-class initializer actually runs.
        super().__init__()

    def get_suggestion(
        self, buffer: "Buffer", document: Document
    ) -> Optional[Suggestion]:
        """
        Return the remainder of the most recent matching history entry,
        or None when nothing matches.

        When looking for the most recent suggestion only lines starting
        with the "[identifier]: " prefix of the selected object count.
        """
        history = buffer.history

        # Consider only the last line for the suggestion.
        text = document.text.rsplit("\n", 1)[-1]
        # Only create a suggestion when this is not an empty line.
        if text.strip():
            obj = self.my_app.selected_object
            prefix = object_to_identifier(obj) + ": "
            # Find first matching line in history, most recent first.
            for string in reversed(list(history.get_strings())):
                for line in reversed(string.splitlines()):
                    loc = line.find(prefix)
                    if loc >= 0 and line[loc + len(prefix):].startswith(text):
                        # Suggest only the part the user has not typed yet.
                        return Suggestion(line[loc + len(prefix) + len(text) :])

        return None
126 |
class PreviewBuffer(Buffer):
    """ Buffer subclass whose up/down history navigation is a no-op.

    History is still recorded (it feeds auto-suggestions), but the arrow
    keys must not cycle through it in the preview filter box.
    """
    def history_forward(self, count: int = 1) -> None:
        " Disabled: down arrow must not browse history. "
        return None

    def history_backward(self, count: int = 1) -> None:
        " Disabled: up arrow must not browse history. "
        return None
134 |
135 |
class PreviewElement:
    """ Class to create the preview element. It contains two main methods:
    create_container: creates the main preview container. Intention is
    for this to land in a float.
    create_completion_float: creates the completion float in the preview
    container. Intention is for this to appear in the FloatContainer that
    hosts the main preview container float.
    """
    def __init__(self, my_app: "sqlApp"):
        self.my_app = my_app
        help_text = """
        Press Enter in the input box to page through the table.
        Alternatively, enter a filtering SQL statement and then press Enter
        to page through the results.
        """
        self.formatter = TabularOutputFormatter()
        # Completer that rebuilds the full preview query so the wrapped
        # completer gets enough context beyond the filter fragment.
        self.completer = PreviewCompleter(
            my_app = self.my_app,
            completer = MssqlCompleter(
                smart_completion = True,
                get_conn = lambda: self.my_app.selected_object.conn))

        history_file = config_location() + 'preview_history'
        ensure_dir_exists(history_file)
        hist = PreviewHistory(
            my_app = self.my_app,
            filename = expanduser(history_file))

        # Filter-input buffer: history, auto-suggest and completion all
        # run on background threads so typing stays responsive.
        self.input_buffer = PreviewBuffer(
            name = "previewbuffer",
            tempfile_suffix = ".sql",
            history = ThreadedHistory(hist),
            auto_suggest =
            ThreadedAutoSuggest(PreviewSuggestFromHistory(my_app)),
            completer = ThreadedCompleter(self.completer),
            # history = hist,
            # auto_suggest = PreviewSuggestFromHistory(my_app),
            # completer = self.completer,
            complete_while_typing = Condition(
                lambda: self.my_app.selected_object is not None and self.my_app.selected_object.conn.connected()
            ),
            multiline = False)

        input_control = BufferControl(
            buffer = self.input_buffer,
            include_default_input_processors = False,
            input_processors = [AppendAutoSuggestion()],
            preview_search = False)

        self.input_window = Window(input_control)

        search_buffer = Buffer(name = "previewsearchbuffer")
        self.search_field = SearchToolbar(search_buffer)
        # Read-only results pane, searchable through the search toolbar.
        self.output_field = TextArea(style = "class:preview-output-field",
            text = help_text,
            height = D(preferred = 50),
            search_field = self.search_field,
            wrap_lines = False,
            focusable = True,
            read_only = True,
            preview_search = True,
            input_processors = [
                ConditionalProcessor(
                    processor=HighlightIncrementalSearchProcessor(),
                    filter=has_focus("previewsearchbuffer")
                    | has_focus(self.search_field.control),
                ),
                HighlightSelectionProcessor(),
            ])

        def refresh_results(window_height) -> bool:
            """ This method gets called when the app restarts after
            exiting for execution of preview query. It populates
            the output buffer with results from the fetch/query.
            """
            sql_conn = self.my_app.selected_object.conn
            if sql_conn.execution_status == executionStatus.FAIL:
                # Let's display the error message to the user
                output = sql_conn.execution_err
            else:
                crsr = sql_conn.cursor
                if crsr.description:
                    cols = [col.name for col in crsr.description]
                else:
                    cols = []
                if len(cols):
                    # Fetch just enough rows to fill the visible window
                    # (4 lines are taken up by borders / header).
                    res = sql_conn.fetch_from_cache(size = window_height - 4,
                            wait = True)
                    output = self.formatter.format_output(res, cols, format_name = "psql")
                    output = "\n".join(output)
                else:
                    output = "No rows returned\n"

            # Add text to output buffer.
            self.output_field.buffer.set_document(Document(
                text = output, cursor_position = 0), True)

            return True

        def accept(buff: Buffer) -> bool:
            """ This method gets called when the user presses enter/return
            in the filter box. It is interpreted as either 'execute query'
            or 'fetch next page of results' if filter query hasn't changed.
            """
            obj = self.my_app.selected_object
            sql_conn = obj.conn
            identifier = object_to_identifier(obj)
            query = sql_conn.preview_query(
                name = identifier,
                obj_type = obj.otype,
                filter_query = buff.text,
                limit = self.my_app.preview_limit_rows)
            if query is None:
                return True

            func = partial(refresh_results,
                window_height = self.output_field.window.render_info.window_height)
            if sql_conn.query != query:
                # Exit the app to execute the query; refresh_results runs
                # as a pre-run callable when the app restarts.
                self.my_app.application.exit(result = ["preview", query])
                self.my_app.application.pre_run_callables.append(func)
            else:
                # No need to exit let's just go and fetch
                func()
            return True # Keep filter text

        def cancel_handler() -> None:
            # Tear down the preview: close the cursor, clear the filter,
            # restore the help text and hand focus back to the sidebar.
            sql_conn = self.my_app.selected_object.conn
            sql_conn.close_cursor()
            self.input_buffer.text = ""
            self.output_field.buffer.set_document(Document(
                text = help_text, cursor_position = 0
            ), True)
            self.my_app.show_preview = False
            self.my_app.show_sidebar = True
            self.my_app.application.layout.focus(self.input_buffer)
            self.my_app.application.layout.focus("sidebarbuffer")
            return None

        self.input_buffer.accept_handler = accept
        self.cancel_button = Button(text = "Done", handler = cancel_handler)

    def create_completion_float(self) -> Float:
        """ Build the completions float, anchored to the filter-input
        window; add it to the FloatContainer hosting the preview. """
        return Float(
            xcursor = True,
            ycursor = True,
            transparent = True,
            attach_to_window = self.input_window,
            content = CompletionsMenu(
                scroll_offset = 1,
                max_height = 16,
                extra_filter = has_focus(self.input_buffer)))

    def create_container(self):
        """ Build the main preview dialog: filter input plus Done button,
        a divider, the results pane and the search toolbar, framed and
        shown only while the preview is active. """

        container = HSplit(
            [
                Box(
                    body = VSplit(
                        [self.input_window, self.cancel_button],
                        padding=1
                    ),
                    padding=1,
                    style="class:preview-input-field"
                ),
                Window(height=1, char="-", style="class:preview-divider-line"),
                self.output_field,
                self.search_field,
            ])

        frame = Shadow(
            body = Frame(
                title = lambda: "Preview: " + self.my_app.selected_object.name,
                body = container,
                style="class:dialog.body",
                width = D(preferred = 180, min = 30),
                modal = True))

        return ConditionalContainer(
            content = frame,
            filter = ShowPreview(self.my_app) & ~is_done)
317 |
--------------------------------------------------------------------------------
/odbcli/sidebar.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import platform
3 | from cyanodbc import Connection
4 | from typing import List, Optional, Callable
5 | from logging import getLogger
6 | from asyncio import get_event_loop
7 | from threading import Thread, Lock
8 | from prompt_toolkit.layout.containers import HSplit, Window, ScrollOffsets, ConditionalContainer, Container
9 | from prompt_toolkit.formatted_text.base import StyleAndTextTuples
10 | from prompt_toolkit.formatted_text import fragment_list_width
11 | from prompt_toolkit.layout.controls import FormattedTextControl, BufferControl, UIContent
12 | from prompt_toolkit.layout.dimension import Dimension
13 | from prompt_toolkit.buffer import Buffer
14 | from prompt_toolkit.document import Document
15 | from prompt_toolkit.filters import is_done, renderer_height_is_known
16 | from prompt_toolkit.layout.margins import ScrollbarMargin
17 | from prompt_toolkit.mouse_events import MouseEvent
18 | from prompt_toolkit.lexers import Lexer
19 | from prompt_toolkit.widgets import SearchToolbar
20 | from prompt_toolkit.filters import Condition
21 | from .conn import sqlConnection
22 | from .filters import ShowSidebar
23 | from .utils import if_mousedown
24 | from .__init__ import __version__
25 |
class myDBObject:
    """ A node in the sidebar's database-object tree.

    Nodes form both a tree (parent / children) and an intrusive,
    circular linked list (next_object) which the sidebar walks
    top-to-bottom to render its rows.  Subclasses implement
    _expand_internal to fetch their children from the database.
    """
    def __init__(
        self,
        my_app: "sqlApp",
        conn: sqlConnection,
        name: str,
        otype: str,
        level: Optional[int] = 0,
        children: Optional[List["myDBObject"]] = None,
        parent: Optional["myDBObject"] = None,
        next_object: Optional["myDBObject"] = None
    ) -> None:

        self.my_app = my_app
        self.conn = conn
        self.children = children
        self.parent = parent
        self.next_object = next_object
        # Held while modifying children, parent, next_object
        # as some of these operations (expand) happen asynchronously.
        self._lock = Lock()
        self.name = name
        self.otype = otype
        self.level = level
        self.selected: bool = False

    def _expand_internal(self) -> None:
        """
        Populates children and sets parent for children nodes
        """
        raise NotImplementedError()

    def expand(self) -> None:
        """
        Populates children and sets parent for children nodes.
        The blocking database I/O runs on a daemon thread; a redraw
        is scheduled back on the event loop once it completes.
        """
        # Already expanded: nothing to do.
        if self.children is not None:
            return None

        loop = get_event_loop()
        # Show the 'Expanding object ...' notice while I/O is in flight.
        self.my_app.show_expanding_object = True
        self.my_app.application.invalidate()
        def _redraw_after_io():
            """ Callback, scheduled after threaded I/O
            completes """
            self.my_app.show_expanding_object = False
            self.my_app.obj_list_changed = True
            self.my_app.application.invalidate()

        def _run():
            """ Executes in a thread """
            self._expand_internal() # Blocking I/O
            loop.call_soon_threadsafe(_redraw_after_io)

        # Don't use 'run_in_executor': a daemon thread is ideal here.
        t = Thread(target = _run, daemon = True)
        t.start()

    def collapse(self) -> None:
        """
        Collapse this node (or, when it has no children, its parent).
        Note, we don't have to blow up the children; just redirect
        next_object. This way we re-query the database / force re-fresh
        which may be suboptimal. TODO: Codify not/refresh path
        """
        # Only the currently selected object may be collapsed.
        if self is not self.my_app.selected_object:
            return
        if self.children is not None:
            # Skip past every descendant: start at the last child's
            # successor and walk until back at (or above) our own level.
            obj = self.children[len(self.children) - 1].next_object
            while obj.level > self.level:
                obj = obj.next_object
            with self._lock:
                self.next_object = obj
                self.children = None
        elif self.parent is not None:
            # Leaf node: collapse the parent instead.
            self.my_app.selected_object = self.parent
            self.parent.collapse()

        self.my_app.obj_list_changed = True

    def add_children(self, list_obj: List["myDBObject"]) -> None:
        """ Splice list_obj (minus empty-named entries) into the linked
        list directly after this node and record them as children. """
        lst = list(filter(lambda x: x.name != "", list_obj))
        if len(lst):
            with self._lock:
                self.children = lst
                for i in range(len(self.children) - 1):
                    self.children[i].next_object = self.children[i + 1]
                self.children[len(self.children) - 1].next_object = self.next_object
                self.next_object = self.children[0]
115 |
class myDBColumn(myDBObject):
    """ Leaf node of the sidebar tree: a single column. """
    def _expand_internal(self) -> None:
        # Columns have no children; expanding is a no-op.
        return None
119 |
class myDBFunction(myDBObject):
    """ Sidebar node for a stored procedure / function.

    Expanding it lists the procedure's columns as children.
    """
    def _expand_internal(self) -> None:
        """ Query procedure columns and attach them as myDBColumn children. """
        # Defaults: search everywhere.  Per the ODBC reference for
        # SQLProcedureColumns, CatalogName cannot contain a string search
        # pattern, hence the catalog name is passed through unsanitized.
        # https://docs.microsoft.com/en-us/sql/odbc/reference/syntax/sqlprocedurecolumns-function?view=sql-server-ver15
        cat, schema = "%", "%"

        parent = self.parent
        if parent is not None:
            parent_kind = type(parent).__name__
            if parent_kind == "myDBSchema":
                schema = self.conn.sanitize_search_string(parent.name)
            elif parent_kind == "myDBCatalog":
                cat = parent.name
            grandparent = parent.parent
            if grandparent is not None and type(grandparent).__name__ == "myDBCatalog":
                cat = grandparent.name

        columns = self.conn.find_procedure_columns(
            catalog = cat,
            schema = schema,
            procedure = self.conn.sanitize_search_string(self.name),
            column = "%")

        children = [
            myDBColumn(
                my_app = self.my_app,
                conn = self.conn,
                name = col.column,
                otype = col.type_name,
                parent = self,
                level = self.level + 1)
            for col in columns]
        self.add_children(list_obj = children)
        return None
152 |
class myDBTable(myDBObject):
    """ Sidebar node for a table or view.

    Expanding it lists the table's columns as children.
    """
    def _expand_internal(self) -> None:
        """ Query the table's columns and attach them as myDBColumn children. """
        # Defaults: search everywhere.  Per the ODBC reference for
        # SQLColumns, CatalogName cannot contain a string search pattern,
        # hence the catalog name is passed through unsanitized.
        # https://docs.microsoft.com/en-us/sql/odbc/reference/syntax/sqlcolumns-function?view=sql-server-ver15
        cat, schema = "%", "%"

        parent = self.parent
        if parent is not None:
            parent_kind = type(parent).__name__
            if parent_kind == "myDBSchema":
                schema = self.conn.sanitize_search_string(parent.name)
            elif parent_kind == "myDBCatalog":
                cat = parent.name
            grandparent = parent.parent
            if grandparent is not None and type(grandparent).__name__ == "myDBCatalog":
                cat = grandparent.name

        columns = self.conn.find_columns(
            catalog = cat,
            schema = schema,
            table = self.name,
            column = "%")

        children = [
            myDBColumn(
                my_app = self.my_app,
                conn = self.conn,
                name = col.column,
                otype = col.type_name,
                parent = self,
                level = self.level + 1)
            for col in columns]
        self.add_children(list_obj = children)
        return None
185 |
class myDBSchema(myDBObject):
    """ Sidebar node for a schema.  Children are its tables, views and
    procedures/functions. """
    def _expand_internal(self) -> None:
        """ List tables/views and procedures, register them in the
        completion metadata cache, and attach them as children. """

        # The catalog name must not be a search pattern; schema names may be.
        cat = self.conn.sanitize_search_string(self.parent.name) if self.parent is not None else "%"
        res = self.conn.find_tables(
            catalog = cat,
            schema = self.conn.sanitize_search_string(self.name),
            table = "",
            type = "")
        resf = self.conn.find_procedures(
            catalog = cat,
            schema = self.conn.sanitize_search_string(self.name),
            procedure = "")
        tables = []
        views = []
        functions = []
        lst = []
        # Every returned object becomes a child node; names are also
        # bucketed by type for the completer's metadata cache below.
        for table in res:
            if table.type.lower() == 'table':
                tables.append(table.name)
            if table.type.lower() == 'view':
                views.append(table.name)
            lst.append(myDBTable(
                my_app = self.my_app,
                conn = self.conn,
                name = table.name,
                otype = table.type.lower(),
                parent = self,
                level = self.level + 1))
        for func in resf:
            functions.append(func.name)
            lst.append(myDBFunction(
                my_app = self.my_app,
                conn = self.conn,
                name = func.name,
                otype = "function",
                parent = self,
                level = self.level + 1))

        # Feed the completer's metadata cache with what was discovered.
        self.conn.dbmetadata.extend_objects(
            catalog = self.conn.escape_name(self.parent.name) if self.parent else "",
            schema = self.conn.escape_name(self.name),
            names = self.conn.escape_names(tables),
            obj_type = "table")
        self.conn.dbmetadata.extend_objects(
            catalog = self.conn.escape_name(self.parent.name) if self.parent else "",
            schema = self.conn.escape_name(self.name),
            names = self.conn.escape_names(views),
            obj_type = "view")
        self.conn.dbmetadata.extend_objects(
            catalog = self.conn.escape_name(self.parent.name) if self.parent else "",
            schema = self.conn.escape_name(self.name),
            names = self.conn.escape_names(functions),
            obj_type = "function")
        self.add_children(list_obj = lst)
        return None
242 |
class myDBCatalog(myDBObject):
    """ Sidebar node for a catalog (database).  Children are schemas
    when the driver reports any, otherwise tables/views directly
    (e.g. MySQL, which has no schema layer under a catalog). """
    def _expand_internal(self) -> None:
        """ List schemas (or fall back to tables), update the completion
        metadata cache, and attach children. """
        schemas = lst = []
        schemas = self.conn.list_schemas(
            catalog = self.conn.sanitize_search_string(self.name))

        # Some drivers return nothing (or only blanks) from list_schemas;
        # fall back to deriving schema names from the table listing.
        if len(schemas) < 1 or all([s == "" for s in schemas]):
            res = self.conn.find_tables(
                catalog = self.conn.sanitize_search_string(self.name),
                schema = "",
                table = "",
                type = "")
            schemas = [r.schema for r in res]

        self.conn.dbmetadata.extend_schemas(
            catalog = self.conn.escape_name(self.name),
            names = self.conn.escape_names(schemas))

        if not all([s == "" for s in schemas]):
            # Schemas were found either having called list_schemas
            # or via the find_tables call
            lst = [myDBSchema(
                my_app = self.my_app,
                conn = self.conn,
                name = schema,
                otype = "schema",
                parent = self,
                level = self.level + 1) for schema in sorted(set(schemas))]
        elif len(schemas):
            # No schemas found; but if there are tables then these are direct
            # descendents, i.e. MySQL
            # NOTE: this branch is reachable only via the fallback above
            # (all schema names blank), so ``res`` is guaranteed bound.
            tables = []
            views = []
            lst = []
            for table in res:
                if table.type.lower() == 'table':
                    tables.append(table.name)
                if table.type.lower() == 'view':
                    views.append(table.name)
                lst.append(myDBTable(
                    my_app = self.my_app,
                    conn = self.conn,
                    name = table.name,
                    otype = table.type.lower(),
                    parent = self,
                    level = self.level + 1))
            self.conn.dbmetadata.extend_objects(
                catalog = self.conn.escape_name(self.name),
                schema = "", names = self.conn.escape_names(tables),
                obj_type = "table")
            self.conn.dbmetadata.extend_objects(
                catalog = self.conn.escape_name(self.name),
                schema = "", names = self.conn.escape_names(views),
                obj_type = "view")

        self.add_children(list_obj = lst)
        return None
300 |
301 |
class myDBConn(myDBObject):
    """ Sidebar root node for one connection.  Children are catalogs
    when the driver supports them, otherwise schemas, otherwise
    tables/views directly. """
    def _expand_internal(self) -> None:
        """ Expand only when connected; populate catalogs / schemas /
        tables and the completion metadata cache. """
        # Not connected yet: nothing to expand.
        if not self.conn.connected():
            return None

        lst = []
        cat_support = self.conn.catalog_support()
        if cat_support:
            rows = self.conn.list_catalogs()
            if len(rows):
                lst = [myDBCatalog(
                    my_app = self.my_app,
                    conn = self.conn,
                    name = row,
                    otype = "catalog",
                    parent = self,
                    level = self.level + 1) for row in rows]
                self.conn.dbmetadata.extend_catalogs(
                    self.conn.escape_names(rows))
        else:
            # No catalog support: derive schemas from the table listing.
            res = self.conn.find_tables(
                catalog = "%",
                schema = "",
                table = "",
                type = "")
            schemas = [r.schema for r in res]
            self.conn.dbmetadata.extend_schemas(catalog = "",
                names = self.conn.escape_names(schemas))
            if not all([s == "" for s in schemas]):
                lst = [myDBSchema(
                    my_app = self.my_app,
                    conn = self.conn,
                    name = schema,
                    otype = "schema",
                    parent = self,
                    level = self.level + 1) for schema in sorted(set(schemas))]
            elif len(schemas):
                # No schema layer either: tables are direct descendants.
                tables = []
                views = []
                lst = []
                for table in res:
                    if table.type.lower() == 'table':
                        tables.append(table.name)
                    if table.type.lower() == 'view':
                        views.append(table.name)
                    lst.append(myDBTable(
                        my_app = self.my_app,
                        conn = self.conn,
                        name = table.name,
                        otype = table.type.lower(),
                        parent = self,
                        level = self.level + 1))
                self.conn.dbmetadata.extend_objects(catalog = "",
                    schema = "", names = self.conn.escape_names(tables),
                    obj_type = "table")
                self.conn.dbmetadata.extend_objects(catalog = "",
                    schema = "", names = self.conn.escape_names(views),
                    obj_type = "view")
        self.add_children(list_obj = lst)
        return None
362 |
def sql_sidebar_help(my_app: "sqlApp"):
    """
    Create the `Layout` for the help text for the current item in the
    sidebar (currently just the selected object's name).
    """
    style_token = "class:sidebar.helptext"

    def get_help_text():
        # Show the selected object's name, or nothing when no selection.
        selected = my_app.selected_object
        description = selected.name if selected is not None else ""
        return [(style_token, description)]

    visible = (
        ~is_done
        & ShowSidebar(my_app)
        & Condition(lambda: not my_app.show_exit_confirmation)
    )
    return ConditionalContainer(
        content=Window(
            FormattedTextControl(get_help_text),
            style=style_token,
            height=Dimension(min=3),
        ),
        filter=visible,
    )
390 |
def expanding_object_notification(my_app: "sqlApp"):
    """
    Create the `Layout` for the 'Expanding object' notification, shown
    in red while a sidebar node is being expanded on a background thread.
    """
    notice = Window(
        FormattedTextControl(lambda: [("fg:red", "Expanding object ...")]),
        style = "class:sidebar",
        width = Dimension.exact(45),
        height = Dimension(max = 1),
    )
    show = (
        ~is_done
        & ShowSidebar(my_app)
        & Condition(lambda: my_app.show_expanding_object)
    )
    return ConditionalContainer(content = notice, filter = show)
412 |
def sql_sidebar_navigation():
    """
    Create the `Layout` showing the navigation key bindings for the
    sidebar (two lines at the bottom of the sidebar).
    """
    key_style = "class:sidebar.navigation.key"
    desc_style = "class:sidebar.navigation.description"
    pad_style = "class:sidebar.navigation"

    fragments = [
        (pad_style, " "),
        (key_style, "[Up/Dn]"),
        (pad_style, " "),
        (desc_style, "Navigate"),
        (pad_style, " "),
        (key_style, "[L/R]"),
        (pad_style, " "),
        (desc_style, "Expand/Collapse"),
        (pad_style, "\n "),
        (key_style, "[Enter]"),
        (pad_style, " "),
        (desc_style, "Connect/Preview"),
    ]

    return Window(
        FormattedTextControl(lambda: fragments),
        style = "class:sidebar.navigation",
        width = Dimension.exact(45),
        height = Dimension(max = 2),
    )
441 |
def show_sidebar_button_info(my_app: "sqlApp") -> Container:
    """
    Create `Layout` for the information in the right-bottom corner.
    (The right part of the status bar.)
    """

    @if_mousedown
    def toggle_sidebar(mouse_event: MouseEvent) -> None:
        " Click handler for the menu. "
        my_app.show_sidebar = not my_app.show_sidebar

    # Static fragments: the sidebar-toggle hint plus the odbcli version.
    # (Removed an unused ``sys.version_info`` lookup left over from a
    # TODO about showing the app version instead of the Python version.)
    tokens: StyleAndTextTuples = [
        ("class:status-toolbar.key", "[C-t]", toggle_sidebar),
        ("class:status-toolbar", " Object Browser", toggle_sidebar),
        ("class:status-toolbar", " - "),
        ("class:status-toolbar.cli-version", "odbcli %s" % __version__),
        ("class:status-toolbar", " "),
    ]
    width = fragment_list_width(tokens)

    def get_text_fragments() -> StyleAndTextTuples:
        return tokens

    return ConditionalContainer(
        content=Window(
            FormattedTextControl(get_text_fragments),
            style="class:status-toolbar",
            height=Dimension.exact(1),
            width=Dimension.exact(width),
        ),
        filter=~is_done
        & Condition(
            lambda: not my_app.show_exit_confirmation
        )
        & renderer_height_is_known
    )
481 |
def sql_sidebar(my_app: "sqlApp") -> Window:
    """
    Create the `Layout` for the sidebar with the configurable objects.
    """

    @if_mousedown
    def expand_item(obj: "myDBObject") -> None:
        # Mouse click on a sidebar row expands that object.
        obj.expand()

    def tokenize_obj(obj: "myDBObject") -> StyleAndTextTuples:
        " Recursively build the token list "
        tokens: StyleAndTextTuples = []
        selected = obj is my_app.selected_object
        expanded = obj.children is not None
        connected = obj.otype == "Connection" and obj.conn.connected()
        active = my_app.active_conn is not None and my_app.active_conn is obj.conn and obj.level == 0

        act = ",active" if active else ""
        sel = ",selected" if selected else ""
        # Trim or pad the name to a fixed width that shrinks with depth.
        if len(obj.name) > 24 - 2 * obj.level:
            name_trim = obj.name[:24 - 2 * obj.level - 3] + "..."
        else:
            name_trim = ("%-" + str(24 - 2 * obj.level) + "s") % obj.name

        tokens.append(("class:sidebar.label" + sel + act, " >" if connected else "  "))
        tokens.append(
            ("class:sidebar.label" + sel, " " * 2 * obj.level, expand_item)
        )
        tokens.append(
            ("class:sidebar.label" + sel + act,
            name_trim,
            expand_item)
        )
        tokens.append(("class:sidebar.status" + sel + act, " ", expand_item))
        tokens.append(("class:sidebar.status" + sel + act, "%+12s" % obj.otype, expand_item))

        if selected:
            tokens.append(("[SetCursorPosition]", ""))

        if expanded:
            tokens.append(("class:sidebar.status" + sel + act, "\/"))
        else:
            tokens.append(("class:sidebar.status" + sel + act, " <" if selected else "  "))

        # Expand past the edge of the visible buffer to get an even panel
        tokens.append(("class:sidebar.status" + sel + act, " " * 10))
        return tokens

    search_buffer = Buffer(name = "sidebarsearchbuffer")
    search_field = SearchToolbar(
        search_buffer = search_buffer,
        ignore_case = True
    )
    def _buffer_pos_changed(buff):
        """ This callback gets executed after cursor position changes. Most
        of the time we register a key-press (up / down), we change the
        selected object and as a result of that the cursor changes. By that
        time we don't need to update the selected object (cursor changed as
        a result of the selected object being updated). The one exception
        is when searching the sidebar buffer. When this happens the cursor
        moves ahead of the selected object. When that happens, here we
        update the selected object to follow suit.
        """
        if buff.document.cursor_position_row != my_app.selected_object_idx[0]:
            my_app.select(buff.document.cursor_position_row)

    sidebar_buffer = Buffer(
        name = "sidebarbuffer",
        read_only = True,
        on_cursor_position_changed = _buffer_pos_changed
    )

    class myLexer(Lexer):
        # Lexer that styles each sidebar line from the parallel object
        # list rather than from the document text itself.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._obj_list = []

        def add_objects(self, objects: List):
            # Keep the flattened object list in sync with the document.
            self._obj_list = objects

        def lex_document(self, document: Document) -> Callable[[int], StyleAndTextTuples]:
            def get_line(lineno: int) -> StyleAndTextTuples:
                # TODO: raise out-of-range exception
                return tokenize_obj(self._obj_list[lineno])
            return get_line


    sidebar_lexer = myLexer()

    class myControl(BufferControl):

        def move_cursor_down(self):
            my_app.select_next()
        # Need to figure out what to do here
        # AFAICT these are only called for the mouse handler
        # when events are otherwise not handled
        def move_cursor_up(self):
            my_app.select_previous()

        def mouse_handler(self, mouse_event: MouseEvent) -> "NotImplementedOrNone":
            """
            There is an intricate relationship between the cursor position
            in the sidebar document and which object is marked as 'selected'
            in the linked list. Let's not muck that up by allowing the user
            to change the cursor position in the sidebar document with the mouse.
            """
            return NotImplemented

        def create_content(self, width: int, height: Optional[int]) -> UIContent:
            # Only traverse the obj_list if it has been expanded / collapsed
            if not my_app.obj_list_changed:
                self.buffer.cursor_position = my_app.selected_object_idx[1]
                return super().create_content(width, height)

            # Flatten the circular linked list into a plain list, one
            # entry per visible sidebar row.
            res = []
            obj = my_app.obj_list[0]
            res.append(obj)
            while obj.next_object is not my_app.obj_list[0]:
                res.append(obj.next_object)
                obj = obj.next_object

            self.lexer.add_objects(res)
            self.buffer.set_document(Document(
                text = "\n".join([a.name for a in res]), cursor_position = my_app.selected_object_idx[1]), True)
            # Reset obj_list_changed flag, now that we have had a chance to
            # regenerate the sidebar document content
            my_app.obj_list_changed = False
            return super().create_content(width, height)



    sidebar_control = myControl(
        buffer = sidebar_buffer,
        lexer = sidebar_lexer,
        search_buffer_control = search_field.control,
        focusable = True,
    )

    return HSplit([
        search_field,
        Window(
            sidebar_control,
            right_margins = [ScrollbarMargin(display_arrows = True)],
            style = "class:sidebar",
            width = Dimension.exact( 45 ),
            height = Dimension(min = 7, preferred = 33),
            scroll_offsets = ScrollOffsets(top = 1, bottom = 1)),
        Window(
            height = Dimension.exact(1),
            char = "\u2500",
            style = "class:sidebar,separator",
        ),
        expanding_object_notification(my_app),
        sql_sidebar_navigation()])
636 |
--------------------------------------------------------------------------------
/odbcli/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | For internal use only.
3 | """
4 | import re
5 | from typing import Callable, TypeVar, cast
6 | import os
7 | from platform import system
8 |
9 | from prompt_toolkit.mouse_events import MouseEvent, MouseEventType
10 |
11 | __all__ = [
12 | "has_unclosed_brackets",
13 | "get_jedi_script_from_document",
14 | "document_is_multiline_python",
15 | ]
16 |
def has_unclosed_brackets(text: str) -> bool:
    """
    Scan *text* from the end and report whether it contains an opening
    bracket that never received a matching closing bracket.
    """
    # Strip quoted strings so brackets inside them are ignored.
    # XXX: escaped quotes are not handled.
    stripped = re.sub(r"""('[^']*'|"[^"]*")""", "", text)

    closer_of = {"[": "]", "{": "}", "(": ")"}
    pending = []  # closing brackets seen so far, scanning right to left

    for ch in reversed(stripped):
        if ch in "])}":
            pending.append(ch)
        elif ch in closer_of:
            if not pending:
                # An opener with no closer anywhere to its right.
                return True
            if pending[-1] == closer_of[ch]:
                pending.pop()
            # A mismatched closer on top of the stack is deliberately
            # left in place (mismatched pairs are tolerated).

    return False
44 |
45 |
def get_jedi_script_from_document(document, locals, globals):
    """
    Build a jedi Interpreter for *document*, or return None whenever
    jedi cannot handle the input.
    """
    # Importing Jedi is 'slow'; keep the import in-line to improve
    # start-up time.
    import jedi

    try:
        return jedi.Interpreter(
            document.text,
            column=document.cursor_position_col,
            line=document.cursor_position_row + 1,
            path="input-text",
            namespaces=[locals, globals],
        )
    except Exception:
        # jedi raises a variety of exceptions on odd input; every one of
        # them means "no script available" here:
        # - ValueError: invalid cursor position
        #   ('`column` parameter is not in a valid range.')
        # - AttributeError: workaround for #65:
        #   https://github.com/jonathanslenders/python-prompt-toolkit/issues/65
        #   see also https://github.com/davidhalter/jedi/issues/508
        # - IndexError: https://github.com/davidhalter/jedi/issues/514
        # - KeyError: crash when the input is "u'", the start of a
        #   unicode string
        # - anything else:
        #   https://github.com/jonathanslenders/ptpython/issues/91
        return None
76 |
77 |
# Matches one triple-quote delimiter (''' or three double quotes); used to
# detect whether a document ends inside an open multiline string.
_multiline_string_delims = re.compile("""[']{3}|["]{3}""")
79 |
80 |
def document_is_multiline_python(document):
    """
    Return ``True`` when *document* should be treated as a multiline
    Python document (i.e. a newline should be inserted instead of
    accepting the input).
    """

    def _ends_inside_triple_quote() -> bool:
        """``True`` if the text ends while a triple-quoted string is open."""
        open_delim = None
        for delim in _multiline_string_delims.findall(document.text):
            if open_delim is None:
                open_delim = delim           # string opened
            elif delim == open_delim:
                open_delim = None            # matching delimiter closed it
        return open_delim is not None

    if "\n" in document.text or _ends_inside_triple_quote():
        return True

    # Just typed a colon, still have open brackets, or writing a decorator:
    # always insert a real newline.
    if document.current_line.rstrip().endswith(":"):
        return True
    if document.is_cursor_at_the_end and has_unclosed_brackets(
        document.text_before_cursor
    ):
        return True
    if document.text.startswith("@"):
        return True

    # A trailing backslash is an explicit line continuation.
    if document.text_before_cursor.endswith("\\"):
        return True

    return False
122 |
123 |
# Type of a mouse-event handler; lets if_mousedown preserve the decorated
# handler's exact callable type.
_T = TypeVar("_T", bound=Callable[[MouseEvent], None])
125 |
126 |
def if_mousedown(handler: _T) -> _T:
    """
    Decorator for mouse handlers: run *handler* only for MOUSE_DOWN events
    and report every other event type as unhandled.

    (When applied to a token list. Scroll events will bubble up and are
    handled by the Window.)
    """

    def wrapper(mouse_event: MouseEvent):
        if mouse_event.event_type != MouseEventType.MOUSE_DOWN:
            # Not handled here — let the event bubble up.
            return NotImplemented
        return handler(mouse_event)

    return cast(_T, wrapper)
143 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools
import os
import re

# The PyPI long description is taken verbatim from the README.
# NOTE(review): no explicit encoding is passed — this relies on the
# platform default; consider open("README.md", encoding="utf-8").
with open("README.md", "r") as fh:
    long_description = fh.read()
7 |
def get_version(package):
    """
    Return the package version as listed in ``__version__`` in
    ``<package>/__init__.py``.

    :param package: package directory name, relative to this file.
    :raises RuntimeError: if no ``__version__`` assignment is found
        (instead of the opaque ``AttributeError`` from ``.group()``
        on a failed search).
    """
    path = os.path.join(os.path.dirname(__file__), package, "__init__.py")
    # Read as bytes and decode explicitly so the result does not depend on
    # the platform's default encoding.
    with open(path, "rb") as f:
        init_py = f.read().decode("utf-8")
    # Raw string for the regex, per convention; pattern is unchanged.
    match = re.search(r"__version__ = ['\"]([^'\"]+)['\"]", init_py)
    if match is None:
        raise RuntimeError("__version__ not found in %s" % path)
    return match.group(1)
16 |
17 |
# Runtime dependencies: cyanodbc supplies the ODBC bindings; the rest power
# the TUI (prompt_toolkit, Pygments), SQL parsing (sqlparse), configuration
# (configobj, click) and tabular output (cli_helpers).
install_requirements = [
    'cyanodbc >= 0.0.3',
    'prompt_toolkit >= 3.0.5',
    'Pygments>=2.6.1',
    'sqlparse >= 0.3.1',
    'configobj >= 5.0.6',
    'click >= 7.1.2',
    'cli_helpers >= 2.0.1'
]

setuptools.setup(
    name = "odbcli",
    # Version is read from odbcli/__init__.py so it lives in one place.
    version = get_version("odbcli"),
    author = "Oliver Gjoneski",
    author_email = "ogjoneski@gmail.com",
    description = "ODBC Client",
    license = 'BSD-3',
    long_description = long_description,
    long_description_content_type = "text/markdown",
    install_requires = install_requirements,
    url = "https://github.com/pypa/odbc-cli",
    # Entry point ships as a plain script at the repo root (not an
    # entry_points console_script).
    scripts=[
        'odbc-cli'
    ],
    packages = setuptools.find_packages(),
    include_package_data = True,
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: OS Independent",
    ],
    # As python prompt toolkit
    python_requires = '>=3.6.1',
)
52 |
--------------------------------------------------------------------------------
/tests/test_dbmetadata.py:
--------------------------------------------------------------------------------
1 | from odbcli.dbmetadata import DbMetadata
2 | import pytest
3 |
def test_catalogs():
    """Catalogs, once extended, are reported identically for every object type."""
    db = DbMetadata()
    catalogs = ["a", "b", "c", ""]
    # Nothing registered yet: no catalogs at all.
    assert db.get_catalogs() is None
    db.extend_catalogs(catalogs)
    # The same catalog list is visible regardless of object type.
    for obj_type in ("table", "view", "function"):
        assert db.get_catalogs(obj_type = obj_type) == catalogs
16 |
def test_schemas():
    """Schemas attach to catalogs; extending an unknown catalog creates it."""
    db = DbMetadata()
    schemas = ["A", "B", "C", ""]
    db.extend_catalogs(["a", "b", "c", ""])
    # Unknown catalog -> None; known but empty catalog -> empty list.
    assert db.get_schemas(catalog = "d") is None
    assert db.get_schemas(catalog = "a") == []
    db.extend_schemas(catalog = "d", names = schemas)
    db.extend_schemas(catalog = "", names = schemas)
    # Extending schemas under "d" implicitly registered it as a catalog.
    assert "d" in db.get_catalogs(obj_type = "table")
    assert db.get_schemas(catalog = "d") == schemas
    assert db.get_schemas(catalog = "") == schemas
34 |
def test_objects():
    """Objects are tracked per catalog/schema and separately per object type."""
    db = DbMetadata()
    tables = ["t1", "t2", "t3"]
    views = ["v1", "v2", "v3"]
    db.extend_catalogs(["a", "b", "c", ""])
    db.extend_schemas(catalog = "a", names = ["A", "B", "C", ""])
    # Unknown schema -> None; known but empty schema -> empty list.
    assert db.get_objects(catalog = "a", schema = "D") is None
    assert db.get_objects(catalog = "a", schema = "A") == []
    db.extend_objects(catalog = "a", schema = "A", names = tables, obj_type = "table")
    # Extending objects under an unknown schema implicitly creates it.
    db.extend_objects(catalog = "a", schema = "D", names = tables, obj_type = "table")
    assert db.get_objects(catalog = "a", schema = "A") == tables
    assert db.get_objects(catalog = "a", schema = "D") == tables
    # Views are stored independently of tables.
    assert db.get_objects(catalog = "a", schema = "D", obj_type = "view") == []
    db.extend_objects(catalog = "a", schema = "A", names = views, obj_type = "view")
    assert db.get_objects(catalog = "a", schema = "A", obj_type = "view") == views
    assert db.get_objects(catalog = "a", schema = "A", obj_type = "table") == tables
60 |
--------------------------------------------------------------------------------
/tests/test_parseutils.py:
--------------------------------------------------------------------------------
1 | from odbcli.completion.parseutils.tables import extract_table_identifiers, TableReference
2 | from sqlparse import parse
3 | import pytest
4 |
def test_get_table_identifiers():
    """extract_table_identifiers yields column references and aliased tables."""
    qry = (
        "SELECT a.col1, b.col2 "
        "FROM abc.def.ghi AS a "
        "INNER JOIN jkl.mno.pqr AS b ON a.id_one = b.id_two"
    )
    identifiers = list(extract_table_identifiers(parse(qry)[0]))
    assert identifiers == [
        # Column references: only schema/name slots are populated.
        TableReference(None, "a", "col1", None, False),
        TableReference(None, "b", "col2", None, False),
        # Fully qualified tables carry catalog, schema, name and alias.
        TableReference("abc", "def", "ghi", "a", False),
        TableReference("jkl", "mno", "pqr", "b", False),
    ]
19 |
--------------------------------------------------------------------------------
/tests/test_sqlcompletion.py:
--------------------------------------------------------------------------------
1 | from odbcli.completion.sqlcompletion import SqlStatement
2 | import pytest
3 |
4 |
@pytest.mark.parametrize(
    "before_cursor, expected",
    [
        # Expected (catalog, schema) parents of the dotted identifier
        # being typed at the cursor.
        (" ", (None, None)),
        ("abc", (None, None)),
        ("abc.", (None, "abc")),
        ("abc.def", (None, "abc")),
        ("abc.def.", ("abc", "def")),
        ("abc.def.ghi", ("abc", "def"))
    ],
)
def test_get_identifier_parents(before_cursor, expected):
    """Identifier parents are resolved from the text before the cursor only."""
    statement = SqlStatement(
        full_text = "SELECT * FROM abc.def.ghi",
        text_before_cursor = before_cursor)
    assert statement.get_identifier_parents() == expected
22 |
--------------------------------------------------------------------------------