# Repository layout:
# ├── tests
# │   └── linters.py
# ├── setup.py
# ├── README.md
# └── pgls
#     ├── __main__.py
#     ├── linter.py
#     └── server.py

# ---- /tests/linters.py ----
from pgls.linter import LINTERS
from unittest import TestCase


def make_test_case(linter):
    """Build a test case for *linter*.

    Placeholder: not implemented yet, so it always returns None.
    """


# ---- /setup.py ----
from setuptools import setup, find_packages

setup(
    name='pgls',
    version='0.0.1',
    description='Language Server for PostgreSQL SQL development',
    packages=find_packages(),
    entry_points={'console_scripts': ['pgls = pgls.__main__:main']},
    install_requires=['psycopg2', 'pglast', 'sqlalchemy', 'pygls'],
)

# ---- /README.md ----
# PostgreSQL language server
# ==========================
#
# This project provides a language server implementation
# (https://langserver.org) for PostgreSQL projects.
# It is under active development and targets the following features:
# - code completion,
# - diagnostics (missing tables/columns, invalid joins),
# - automatic query refactoring (extract to sub-query, inline sub-query...).
#
# This is an early draft; more work will be published here in the coming
# months.
#
# You are very welcome to try it out, find and fix bugs, and submit PRs.
# ---- /pgls/__main__.py ----
import argparse
import logging

from pgls.server import pg_language_server

logging.basicConfig(filename="pygls.log", level=logging.DEBUG, filemode="w")


def add_arguments(parser):
    """Configure *parser* with the server's command-line options.

    Options: ``--tcp`` (serve over TCP instead of stdio), ``--host`` and
    ``--port`` (the TCP bind address, used only with ``--tcp``).
    """
    # Previously the copy-pasted pygls sample text "simple json server example".
    parser.description = "Language Server for PostgreSQL SQL development"

    parser.add_argument(
        "--tcp", action="store_true",
        help="Use TCP server instead of stdio"
    )
    parser.add_argument(
        "--host", default="127.0.0.1",
        help="Bind to this address"
    )
    parser.add_argument(
        "--port", type=int, default=2087,
        help="Bind to this port"
    )


def main():
    """Parse the command line, then serve over TCP (``--tcp``) or stdio."""
    parser = argparse.ArgumentParser()
    add_arguments(parser)
    args = parser.parse_args()
    if args.tcp:
        pg_language_server.start_tcp(args.host, args.port)
    else:
        pg_language_server.start_io()


if __name__ == '__main__':
    main()

# ---- /pgls/linter.py ----
from pygls.types import Range, Position, Diagnostic, DiagnosticSeverity
from pglast import Node

# Registry of all linters, keyed by linter code (filled by the ``linter``
# decorator).
LINTERS = {}


class Linter:
    """A registered lint check wrapping a generator function.

    :param code: linter code; its first letter selects the diagnostic
        severity (W=warning, E=error, H=hint, I=information).
    :param name: display name; defaults to the wrapped function's name.
    :param fun: callable ``fun(statement, metadata, context)`` yielding
        ``(node, message)`` pairs. Its docstring becomes ``description``.
    """

    # Maps the first character of ``code`` to the LSP severity.
    _SEVERITY_BY_PREFIX = {
        'W': DiagnosticSeverity.Warning,
        'E': DiagnosticSeverity.Error,
        'H': DiagnosticSeverity.Hint,
        'I': DiagnosticSeverity.Information,
    }

    def __init__(self, code, name, fun):
        self.code = code
        self.name = name or fun.__name__
        self.fun = fun
        self.description = self.fun.__doc__

    def __call__(self, *args, **kwargs):
        return self.fun(*args, **kwargs)

    def __repr__(self):
        return '%s(%r, %r)' % (type(self).__name__, self.code, self.name)

    @property
    def severity(self):
        """Severity derived from the code prefix, or None if it is unknown.

        None leaves the severity unset on the diagnostic, so the client
        decides how to display it.
        """
        return self._SEVERITY_BY_PREFIX.get(self.code[:1])
class linter:
    """Decorator that registers a lint function in ``LINTERS``.

    Usage::

        @linter('WDS0001')
        def check(statement, metadata, context):
            yield node, 'message'

    The decorated name is rebound to the resulting ``Linter`` object.

    :raises KeyError: if a linter with the same code is already registered.
    """

    def __init__(self, code, name=None):
        self.name = name
        self.code = code

    def __call__(self, fun):
        if self.code in LINTERS:
            raise KeyError('Linter %s already exists' % self.code)
        registered = Linter(self.code, self.name, fun)
        LINTERS[self.code] = registered
        return registered


class LinterContext:
    """Settings shared by linters (currently only the schema search path)."""

    def __init__(self, search_path):
        self.search_path = search_path


def offset_to_position(source, offset):
    """Convert a 0-based character offset in *source* to an LSP ``Position``.

    The offset is clamped to ``[0, len(source)]``, so the result is always a
    valid position. Lines are split on ``'\\n'`` only.

    NOTE(review): LSP columns count UTF-16 code units, and PostgreSQL node
    locations may be byte offsets; both can drift on non-ASCII text.
    TODO confirm.
    """
    offset = max(0, min(offset, len(source)))
    line = source.count('\n', 0, offset)
    line_start = source.rfind('\n', 0, offset) + 1
    return Position(line, offset - line_start)


def _node_location(node):
    """Return the non-negative ``location`` stored on *node*, or None.

    Reads the pglast parse-tree dict defensively. Statement nodes such as
    ``DeleteStmt`` carry no location of their own.
    """
    tree = getattr(node, 'parse_tree', None)
    if isinstance(tree, dict):
        location = tree.get('location')
        if isinstance(location, int) and location >= 0:
            return location
    return None


def _make_diagnostic(linter, node, message, offset=0, source=None):
    """Build a one-character ``Diagnostic`` for *node* reported by *linter*.

    :param offset: character offset of the enclosing statement, used when
        *node* has no location of its own.
    :param source: full SQL text. If given, offsets are mapped to line and
        column. Otherwise everything is reported on line 0 at column *offset*.
    """
    location = _node_location(node)
    if location is None:
        location = offset
        if source is not None:
            # The statement offset may include leading whitespace (e.g. after
            # a previous ';'); start the range on the first real character.
            while location < len(source) and source[location].isspace():
                location += 1
    if source is None:
        start_pos = Position(0, location)
        end_pos = Position(0, location + 1)
    else:
        start_pos = offset_to_position(source, location)
        end_pos = offset_to_position(source, location + 1)
    return Diagnostic(
        Range(start_pos, end_pos),
        message=message,
        code=linter.code,
        severity=linter.severity)


def lint(statement, metadata, context, source=None):
    """Yield the diagnostics of every registered linter for one statement.

    :param statement: one element of the ``pglast.parse_sql`` result. A
        ``{'RawStmt': {'stmt': ..., 'stmt_location': ...}}`` envelope is
        unwrapped, so linters see the real statement node (e.g.
        ``DeleteStmt``). Other dicts are wrapped in ``Node`` unchanged.
    :param metadata: passed through to each linter.
    :param context: passed through to each linter.
    :param source: optional full SQL text, used for line/column positions.
    """
    offset = 0
    raw = statement.get('RawStmt') if isinstance(statement, dict) else None
    if isinstance(raw, dict) and 'stmt' in raw:
        offset = raw.get('stmt_location', 0)
        statement = raw['stmt']
    node = Node(statement)
    for registered in LINTERS.values():
        for target, message in registered(node, metadata, context):
            yield _make_diagnostic(registered, target, message, offset, source)


@linter('WDS0001')
def dml_missing_where_clause(statement, metadata, context):
    """Warn about DELETE or UPDATE statements that have no WHERE clause."""
    # Deliberately no print(): with start_io() stdout carries the LSP
    # JSON-RPC stream, and stray output corrupts it.
    if statement.node_tag in ('DeleteStmt', 'UpdateStmt'):
        if not statement.whereClause:
            yield statement, 'Missing WHERE clause in %s' % statement.node_tag

# ---- /pgls/server.py ----
import json
import traceback

from pygls.features import (COMPLETION, TEXT_DOCUMENT_DID_CHANGE,
                            TEXT_DOCUMENT_DID_CLOSE, TEXT_DOCUMENT_DID_OPEN,
                            TEXT_DOCUMENT_DID_SAVE)
from pygls.server import LanguageServer
from pygls.types import (CompletionItem, CompletionList, CompletionParams,
                         ConfigurationItem, ConfigurationParams, Diagnostic,
                         DiagnosticSeverity,
                         DidChangeTextDocumentParams,
                         DidCloseTextDocumentParams, DidOpenTextDocumentParams,
                         DidSaveTextDocumentParams,
                         Location,
                         MessageType, Position, Range, Registration,
                         RegistrationParams, Unregistration,
                         UnregistrationParams)
from pygls.protocol import LanguageServerProtocol, logger as protocollogger
from pglast import parse_sql
from pglast.parser import ParseError

from pgls.linter import lint, offset_to_position


class PgLanguageServer(LanguageServer):
    """pygls language server for PostgreSQL (no customisation yet)."""


class JSONEncoder(json.JSONEncoder):
    """Serialise pygls type objects as dicts of their non-None attributes."""

    def default(self, obj):
        attributes = getattr(obj, '__dict__', None)
        if attributes is None:
            # Not an attribute bag: let the base class raise its TypeError.
            return super().default(obj)
        return {key: value for key, value in attributes.items()
                if value is not None}


def char_pos_to_position(buf, position):
    """Convert a 1-based character position in *buf* to an LSP ``Position``.

    Out-of-range positions are clamped to the start or end of *buf* instead
    of returning None or a negative column.
    """
    return offset_to_position(buf, position - 1)


class PgLanguageProtocol(LanguageServerProtocol):
    """Protocol that serialises outgoing messages with ``JSONEncoder``."""

    def _send_data(self, data):
        """Send *data* to the client as a framed JSON-RPC message.

        Falsy *data* is ignored. Serialisation or transport errors are logged,
        not raised.
        """
        if not data:
            return

        try:
            body = json.dumps(data, cls=JSONEncoder)
            payload = body.encode(self.CHARSET)
            header = (
                'Content-Length: {}\r\n'
                'Content-Type: {}; charset={}\r\n\r\n'
            ).format(len(payload), self.CONTENT_TYPE, self.CHARSET)

            protocollogger.info('Sending data: %s', body)

            self.transport.write(header.encode(self.CHARSET) + payload)
        except Exception:
            protocollogger.error(traceback.format_exc())


pg_language_server = PgLanguageServer(protocol_cls=PgLanguageProtocol)


def _validate(ls, params):
    """Lint the document named in *params* and publish its diagnostics."""
    ls.show_message_log('Linting SQL...')
    text_doc = ls.workspace.get_document(params.textDocument.uri)
    diagnostics = _validate_sql(text_doc.source, params.textDocument.uri)
    ls.publish_diagnostics(text_doc.uri, diagnostics)


def _validate_sql(sqlfile, uri):
    """Return the list of diagnostics for the SQL text *sqlfile*.

    A parse error yields a single error diagnostic at the reported position.
    Otherwise every statement is run through the registered linters.
    *uri* is currently unused.
    """
    try:
        statements = parse_sql(sqlfile)
    except ParseError as e:
        pos = char_pos_to_position(sqlfile, e.location)
        return [Diagnostic(Range(pos, pos),
                           message=e.args[0],
                           severity=DiagnosticSeverity.Error,
                           source=type(pg_language_server).__name__)]
    diagnostics = []
    # Defensive: treat a None result (e.g. possibly for empty input) as "no
    # statements".
    for statement in statements or ():
        diagnostics.extend(lint(statement, None, None, source=sqlfile))
    return diagnostics


@pg_language_server.feature(TEXT_DOCUMENT_DID_CHANGE)
def did_change(ls, params: DidChangeTextDocumentParams):
    """Text document did change notification."""
    _validate(ls, params)


@pg_language_server.feature(TEXT_DOCUMENT_DID_OPEN)
async def did_open(ls, params: DidOpenTextDocumentParams):
    """Text document did open notification."""
    ls.show_message('Open SQL FILE...')
    _validate(ls, params)


@pg_language_server.feature(TEXT_DOCUMENT_DID_SAVE)
async def did_save(ls, params: DidSaveTextDocumentParams):
    """Text document did save notification."""
    ls.show_message('Saving SQL FILE...')
    _validate(ls, params)