├── src ├── nebulo │ ├── gql │ │ ├── __init__.py │ │ ├── relay │ │ │ ├── __init__.py │ │ │ ├── page_info.py │ │ │ ├── cursor.py │ │ │ └── node_interface.py │ │ ├── convert │ │ │ ├── __init__.py │ │ │ ├── delete.py │ │ │ ├── connection.py │ │ │ ├── create.py │ │ │ ├── table.py │ │ │ ├── update.py │ │ │ ├── column.py │ │ │ └── function.py │ │ ├── resolve │ │ │ ├── __init__.py │ │ │ ├── resolvers │ │ │ │ ├── __init__.py │ │ │ │ ├── default.py │ │ │ │ ├── claims.py │ │ │ │ └── asynchronous.py │ │ │ └── transpile │ │ │ │ ├── __init__.py │ │ │ │ └── mutation_builder.py │ │ ├── sqla_to_gql.py │ │ ├── parse_info.py │ │ └── alias.py │ ├── sql │ │ ├── __init__.py │ │ ├── reflection │ │ │ ├── __init__.py │ │ │ ├── names.py │ │ │ ├── constraint_comments.py │ │ │ ├── manager.py │ │ │ ├── types.py │ │ │ ├── views.py │ │ │ └── function.py │ │ ├── statement_helpers.py │ │ ├── table_base.py │ │ ├── composite.py │ │ └── inspect.py │ ├── server │ │ ├── __init__.py │ │ ├── routes │ │ │ ├── __init__.py │ │ │ ├── graphiql.py │ │ │ └── graphql.py │ │ ├── app.py │ │ ├── exception.py │ │ ├── jwt.py │ │ └── starlette.py │ ├── __init__.py │ ├── text_utils │ │ ├── __init__.py │ │ ├── pluralize.py │ │ ├── base64.py │ │ └── casing.py │ ├── exceptions.py │ ├── lexer.py │ ├── env.py │ ├── cli.py │ └── config.py └── test │ ├── lexer_test.py │ ├── parse_info_test.py │ ├── exclude_table_test.py │ ├── rename_column_test.py │ ├── rename_table_test.py │ ├── exclude_fkey_test.py │ ├── cli_test.py │ ├── exclude_column_test.py │ ├── integration_array_test.py │ ├── integration_enum_test.py │ ├── integration_composite_test.py │ ├── reflect_constraint_comment_test.py │ ├── create_test.py │ ├── delete_test.py │ ├── jwt_function_test.py │ ├── update_test.py │ ├── node_id_test.py │ ├── view_test.py │ ├── function_returns_row_test.py │ ├── integration_one_basic_test.py │ ├── integration_many_basic_test.py │ ├── condition_test.py │ ├── function_test.py │ ├── reflection_test.py │ ├── app_test.py │ ├── 
integration_relationship_test.py │ ├── conftest.py │ └── pagination_test.py ├── MANIFEST.in ├── docs ├── images │ └── graphiql.png ├── sqlalchemy_interop.md ├── row_level_security.md ├── supported_types.md ├── supported_entities.md ├── Makefile ├── make.bat ├── authentication.md ├── cli.md ├── comment_directives.md ├── performance.md ├── public_api.md ├── index.md └── quickstart.md ├── pytest.ini ├── .flake8 ├── .isort.cfg ├── .coveragerc ├── pyproject.toml ├── mypy.ini ├── Dockerfile ├── examples ├── sqla_interop │ ├── README.md │ └── app.py └── row_level_security │ ├── README.md │ └── setup.sql ├── .github └── workflows │ ├── pre-commit_hooks.yaml │ └── test.yml ├── mkdocs.yml ├── LICENSE ├── .pre-commit-config.yaml ├── setup.py ├── .gitignore └── README.md /src/nebulo/gql/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nebulo/sql/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nebulo/gql/relay/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nebulo/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nebulo/gql/convert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nebulo/gql/resolve/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/nebulo/sql/reflection/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nebulo/gql/resolve/resolvers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/nebulo/gql/resolve/transpile/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune src/test 2 | include src/nebulo/static/* 3 | -------------------------------------------------------------------------------- /src/test/lexer_test.py: -------------------------------------------------------------------------------- 1 | def test_lexer_import(): 2 | pass 3 | -------------------------------------------------------------------------------- /src/nebulo/__init__.py: -------------------------------------------------------------------------------- 1 | VERSION = "0.3.1" 2 | 3 | __all__ = ["VERSION"] 4 | -------------------------------------------------------------------------------- /docs/images/graphiql.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/olirice/nebulo/HEAD/docs/images/graphiql.png -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = -p no:warnings src/test 3 | #addopts = -p no:warnings --cov=nebulo src/test 4 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore 
= E203, E266, E501, W503, F403, F401, C0330 3 | max-line-length = 88 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | -------------------------------------------------------------------------------- /src/nebulo/sql/statement_helpers.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import literal_column 2 | 3 | 4 | def literal_string(text): 5 | return literal_column(f"'{text}'") 6 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | known_third_party = appdirs,click,databases,flupy,graphql,inflect,jwt,nebulo,pygments,pytest,setuptools,sqlalchemy,starlette,typing_extensions,uvicorn 3 | -------------------------------------------------------------------------------- /src/nebulo/text_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .casing import camel_to_snake, snake_to_camel 2 | from .pluralize import to_plural 3 | 4 | __all__ = ["to_plural", "snake_to_camel", "camel_to_snake"] 5 | -------------------------------------------------------------------------------- /src/nebulo/exceptions.py: -------------------------------------------------------------------------------- 1 | class NebuloException(Exception): 2 | """Base exception for nebulo package""" 3 | 4 | 5 | class SQLParseError(NebuloException): 6 | """An entity could not be parsed""" 7 | -------------------------------------------------------------------------------- /src/nebulo/server/routes/__init__.py: -------------------------------------------------------------------------------- 1 | from .graphiql import get_graphiql_route 2 | from .graphql import get_graphql_route 3 | 4 | __all__ = [ 5 | "get_graphiql_route", 6 | "get_graphql_route", 7 | ] 8 | -------------------------------------------------------------------------------- /.coveragerc: 
-------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | pragma: no cover 4 | if TYPE_CHECKING: 5 | if typing.TYPE_CHECKING 6 | except ImportError 7 | 8 | [run] 9 | omit = 10 | src/main/python/nebulo/lexer.py 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | exclude = ''' 4 | /( 5 | \.git 6 | | \.hg 7 | | \.mypy_cache 8 | | \.tox 9 | | \.venv 10 | | _build 11 | | buck-out 12 | | build 13 | | dist 14 | )/ 15 | ''' 16 | -------------------------------------------------------------------------------- /src/nebulo/sql/table_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | A base class to derive sql tables from 3 | """ 4 | 5 | from sqlalchemy import Table 6 | from typing_extensions import Protocol, runtime_checkable 7 | 8 | 9 | @runtime_checkable 10 | class TableProtocol(Protocol): 11 | __table__: Table 12 | -------------------------------------------------------------------------------- /docs/sqlalchemy_interop.md: -------------------------------------------------------------------------------- 1 | # SQLAlchemy Interop Example 2 | 3 | An example showing how to add a nebulo-powered API to an existing SQLAlchemy project can be found at the link below. 4 | 5 | [SQLAlchemy interop example project](https://github.com/olirice/nebulo/tree/master/examples/sqla_interop) 6 | 7 | 8 | -------------------------------------------------------------------------------- /docs/row_level_security.md: -------------------------------------------------------------------------------- 1 | # Row Level Security Example 2 | 3 | The following is a complete example showing how to leverage PostgreSQL row level security with JWTs to restrict data access. 
4 | 5 | [row level security example project](https://github.com/olirice/nebulo/tree/master/examples/row_level_security) 6 | 7 | 8 | -------------------------------------------------------------------------------- /src/nebulo/server/app.py: -------------------------------------------------------------------------------- 1 | from nebulo.config import Config 2 | from nebulo.server.starlette import create_app 3 | 4 | APP = create_app( 5 | schema=Config.SCHEMA, 6 | connection=Config.CONNECTION, 7 | jwt_identifier=Config.JWT_IDENTIFIER, 8 | jwt_secret=Config.JWT_SECRET, 9 | default_role=Config.DEFAULT_ROLE, 10 | ) 11 | -------------------------------------------------------------------------------- /docs/supported_types.md: -------------------------------------------------------------------------------- 1 | # Supported Types 2 | 3 | There is first class support for the following PG types: 4 | 5 | - Boolean 6 | - Integer 7 | - BigInteger 8 | - Float 9 | - Numeric (casts to float) 10 | - Text 11 | - Timestamp (as ISO 8601 string) 12 | - UUID 13 | - Composite (excludes nested composites) 14 | - Enum 15 | 16 | Other, unknown types, default to a String representation. 
17 | -------------------------------------------------------------------------------- /src/test/parse_info_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from nebulo.gql.parse_info import ASTNode 3 | 4 | 5 | def test_astnode_get_subfield_alias_exception(): 6 | with pytest.raises(Exception): 7 | # Calling instance method on class 8 | # No instance variables used before the thing 9 | # we're trying to test 10 | ASTNode.get_subfield_alias(None, path=[]) # type: ignore 11 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | plugins = sqlmypy 3 | 4 | follow_imports = skip 5 | strict_optional = True 6 | warn_redundant_casts = True 7 | warn_unused_ignores = False 8 | disallow_any_generics = False 9 | check_untyped_defs = True 10 | no_implicit_reexport = True 11 | ignore_missing_imports = True 12 | 13 | 14 | # Strict Mode 15 | disallow_untyped_defs = False 16 | 17 | [nebulo.gql.query_builder] 18 | ignore_errors=True 19 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8.2 2 | 3 | RUN apt-get update 4 | RUN pip3 install -U pip 5 | RUN pip3 install -U setuptools 6 | RUN pip3 install "nebulo==0.0.7" 7 | 8 | ENV DB_CONNECTION="" 9 | ENV SCHEMA="public" 10 | ENV JWT_IDENTIFIER="" 11 | ENV JWT_SECRET="" 12 | 13 | EXPOSE 80 14 | 15 | CMD ["sh", "-c", "neb run -c $DB_CONNECTION --schema $SCHEMA --jwt-identifier \"$JWT_IDENTIFIER\" --jwt-secret \"$JWT_SECRET\" --port 80"] 16 | -------------------------------------------------------------------------------- /src/test/exclude_table_test.py: -------------------------------------------------------------------------------- 1 | SQL_SETUP = """ 2 | create table friend( 3 | id serial primary key, 4 
| email text, 5 | dob date, 6 | created_at timestamp 7 | ); 8 | 9 | comment on table friend is E'@exclude read'; 10 | """ 11 | 12 | 13 | def test_exclude_table(schema_builder): 14 | schema = schema_builder(SQL_SETUP) 15 | assert schema.query_type is None 16 | assert schema.mutation_type is not None 17 | -------------------------------------------------------------------------------- /src/nebulo/server/exception.py: -------------------------------------------------------------------------------- 1 | from starlette.exceptions import HTTPException 2 | from starlette.requests import Request 3 | from starlette.responses import JSONResponse 4 | 5 | 6 | async def http_exception(request: Request, exc: HTTPException): 7 | """Starlette exception handler converting starlette.exceptions.HTTPException into GraphQL responses""" 8 | return JSONResponse({"data": None, "errors": [exc.detail]}, status_code=exc.status_code) 9 | -------------------------------------------------------------------------------- /src/nebulo/gql/relay/page_info.py: -------------------------------------------------------------------------------- 1 | from nebulo.gql.alias import Boolean, Field, NonNull, ObjectType 2 | from nebulo.gql.relay.cursor import Cursor 3 | 4 | PageInfo = ObjectType( 5 | name="PageInfo", 6 | fields={ 7 | "hasNextPage": Field(NonNull(Boolean)), 8 | "hasPreviousPage": Field(NonNull(Boolean)), 9 | "startCursor": Field(Cursor), 10 | "endCursor": Field(Cursor), 11 | }, 12 | description="", 13 | ) 14 | -------------------------------------------------------------------------------- /src/nebulo/text_utils/pluralize.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | from inflect import engine 4 | 5 | 6 | @lru_cache() 7 | def get_pluralizer() -> engine: 8 | """Return an instance of inflection library's engine. 
9 | This is wrapped in a function to reduce import side effects""" 10 | return engine() 11 | 12 | 13 | @lru_cache() 14 | def to_plural(text: str) -> str: 15 | pluralizer = get_pluralizer() 16 | return pluralizer.plural(text) 17 | -------------------------------------------------------------------------------- /src/test/rename_column_test.py: -------------------------------------------------------------------------------- 1 | SQL_SETUP = """ 2 | create table person( 3 | id serial primary key, 4 | email text, 5 | dob date, 6 | created_at timestamp 7 | ); 8 | 9 | comment on column person.email is E'@name contact_email'; 10 | """ 11 | 12 | 13 | def test_reflect_column_name_directive(schema_builder): 14 | schema = schema_builder(SQL_SETUP) 15 | 16 | assert "email" not in schema.type_map["Person"].fields 17 | assert "contact_email" in schema.type_map["Person"].fields 18 | -------------------------------------------------------------------------------- /src/test/rename_table_test.py: -------------------------------------------------------------------------------- 1 | SQL_SETUP = """ 2 | create table friend( 3 | id serial primary key, 4 | email text, 5 | dob date, 6 | created_at timestamp 7 | ); 8 | 9 | comment on table friend is E'@name Associate'; 10 | """ 11 | 12 | 13 | def test_reflect_table_name_directive(schema_builder): 14 | schema = schema_builder(SQL_SETUP) 15 | 16 | assert "Friend" not in schema.type_map 17 | assert "friend" not in schema.type_map 18 | assert "Associate" in schema.type_map 19 | -------------------------------------------------------------------------------- /examples/sqla_interop/README.md: -------------------------------------------------------------------------------- 1 | # Nebulo Example: 2 | 3 | app.py is a minimal example showing how to integrate nebulo with an existing SQLAlchemy project. 
4 | 5 | 6 | # Start the database 7 | ```shell 8 | docker run --rm --name nebulo_example_local -p 5522:5432 -d -e POSTGRES_DB=nebulo_example -e POSTGRES_PASSWORD=app_password -e POSTGRES_USER=app_user -d postgres 9 | ``` 10 | 11 | Then start the webserver 12 | 13 | ```shell 14 | python app.py 15 | ``` 16 | 17 | Explore the GraphQL API at [http://0.0.0.0:5034/graphiql](http://0.0.0.0:5034/graphiql) 18 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit_hooks.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit hooks 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: Checkout 11 | uses: actions/checkout@v2 12 | 13 | - name: Set up Python 3.7 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: 3.7 17 | 18 | - name: Install Pre-Commit 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install pre-commit 22 | 23 | 24 | - name: Run Hooks 25 | run: | 26 | pre-commit run --all-files 27 | -------------------------------------------------------------------------------- /src/nebulo/text_utils/base64.py: -------------------------------------------------------------------------------- 1 | from base64 import b64decode as _unbase64 2 | from base64 import b64encode as _base64 3 | 4 | import sqlalchemy 5 | from sqlalchemy import cast, func 6 | from sqlalchemy.dialects.postgresql import BYTEA 7 | 8 | 9 | def to_base64(string): 10 | return _base64(string.encode("utf-8")).decode("utf-8") 11 | 12 | 13 | def from_base64(string): 14 | return _unbase64(string).decode("utf-8") 15 | 16 | 17 | def to_base64_sql(text_to_encode): 18 | encoding = "base64" 19 | return func.encode(cast(text_to_encode, BYTEA()), cast(encoding, sqlalchemy.Text())) 20 | -------------------------------------------------------------------------------- /src/test/exclude_fkey_test.py: 
-------------------------------------------------------------------------------- 1 | SQL_UP = """ 2 | CREATE TABLE public.person ( 3 | id serial primary key 4 | ); 5 | 6 | CREATE TABLE public.address ( 7 | id serial primary key, 8 | person_id int, 9 | 10 | constraint fk_person foreign key (person_id) references public.person (id) 11 | ); 12 | 13 | comment on constraint fk_person on public.address is '@name Person Addresses 14 | @exclude read 15 | '; 16 | """ 17 | 18 | 19 | def test_fkey_comment_exclude_one(schema_builder): 20 | schema = schema_builder(SQL_UP) 21 | assert "Addresses" not in schema.type_map["Person"].fields 22 | assert "Person" not in schema.type_map["Address"].fields 23 | -------------------------------------------------------------------------------- /docs/supported_entities.md: -------------------------------------------------------------------------------- 1 | # Supported Entities 2 | 3 | There is automatic reflection support for: 4 | 5 | - Tables 6 | - Views 7 | - Functions 8 | 9 | 10 | 11 | ### Views 12 | 13 | To reflect correctly, views must 14 | 15 | ``` 16 | COMMENT ON TABLE [schema].[table] 17 | IS E'@primary_key (part_1, part_2)'; 18 | ``` 19 | 20 | Views also support a virtual foreign key comment to establish a connection to the broader graph. 21 | 22 | ``` 23 | @foreign_key (variant_id, key2) references public.variant (id, key2) 24 | ``` 25 | OR 26 | ``` 27 | @foreign_key (variant_id, key2) references public.variant (id, key2) LocalNameForRemote RemoteNameForLocal' 28 | ``` 29 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /src/test/cli_test.py: -------------------------------------------------------------------------------- 1 | from click.testing import CliRunner 2 | from nebulo.cli import dump_schema, main 3 | 4 | 5 | def test_cli_version(): 6 | runner = CliRunner() 7 | resp = runner.invoke(main, ["--version"]) 8 | assert resp.exit_code == 0 9 | print(resp.output) 10 | assert "version" in resp.output 11 | 12 | 13 | def test_cli_schema_dump(app_builder, connection_str): 14 | runner = CliRunner() 15 | 16 | _ = app_builder( 17 | """ 18 | create table account ( 19 | id serial primary key 20 | ); 21 | """ 22 | ) 23 | 24 | resp = runner.invoke(dump_schema, ["-c", connection_str]) 25 | assert resp.exit_code == 0 26 | print(resp.output) 27 | assert "Query" in resp.output 28 | -------------------------------------------------------------------------------- /src/nebulo/text_utils/casing.py: -------------------------------------------------------------------------------- 1 | import re 2 | from functools import lru_cache 3 | 4 | __all__ = ["snake_to_camel"] 5 | 6 | _re_snake_to_camel = re.compile(r"(_)([a-z\d])") 7 | 8 | 9 | @lru_cache() 10 | def snake_to_camel(s: str, upper: bool = True) -> str: 11 | """Convert from snake_case to CamelCase 12 | If upper is set, then convert to upper CamelCase, otherwise the first character 13 | keeps its case. 
14 | """ 15 | s = _re_snake_to_camel.sub(lambda m: m.group(2).upper(), s) 16 | if upper: 17 | s = s[:1].upper() + s[1:] 18 | return s 19 | 20 | 21 | @lru_cache() 22 | def camel_to_snake(s: str) -> str: 23 | """Convert from CamelCase to snake_case""" 24 | return "".join(["_" + i.lower() if i.isupper() else i for i in s]).lstrip("_") 25 | -------------------------------------------------------------------------------- /src/nebulo/gql/resolve/resolvers/default.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from graphql.pyutils import Path 4 | from nebulo.gql.alias import ResolveInfo 5 | 6 | 7 | def default_resolver(_, info: ResolveInfo, **kwargs) -> typing.Any: 8 | """Expects the final, complete result to exist in context['result'] 9 | and uses the current path to retrieve and return the expected result for 10 | the current location in the query 11 | 12 | Why TF would you do this? 13 | - To avoid writing a custom executor 14 | - To keep the resolver tree extensible for end users 15 | - To keep everything as default as possible 16 | """ 17 | path: Path = info.path 18 | context = info.context 19 | final_result = context["result"] 20 | 21 | remain_result = final_result 22 | for elem in path.as_list(): 23 | remain_result = remain_result[elem] 24 | 25 | return remain_result 26 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /src/test/exclude_column_test.py: -------------------------------------------------------------------------------- 1 | SQL_SETUP = """ 2 | create table person( 3 | id serial primary key, 4 | email text, 5 | dob date, 6 | created_at timestamp 7 | ); 8 | 9 | comment on column person.created_at is E'@exclude create, update'; 10 | comment on column person.email is E'@exclude update'; 11 | comment on column person.dob is E'@exclude read'; 12 | """ 13 | 14 | 15 | def test_reflect_function(schema_builder): 16 | schema = schema_builder(SQL_SETUP) 17 | 18 | assert "id" in schema.type_map["Person"].fields 19 | assert "dob" not in schema.type_map["Person"].fields 20 | assert "createdAt" in schema.type_map["Person"].fields 21 | assert "createdAt" not in schema.type_map["PersonInput"].fields 22 | assert "createdAt" not in schema.type_map["PersonPatch"].fields 23 | assert "email" in schema.type_map["PersonInput"].fields 24 | assert "email" not in schema.type_map["PersonPatch"].fields 25 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Nebulo 2 | site_description: Instant GraphQL API from PostgreSQL Database 3 | 4 | theme: 5 | name: 'readthedocs' 6 | 7 | repo_name: olirice/nebulo 8 | repo_url: 
https://github.com/olirice/nebulo 9 | 10 | nav: 11 | - Welcome: 12 | - Introduction: 'index.md' 13 | - Quickstart: 'quickstart.md' 14 | - Feature Support: 15 | - Types: 'supported_types.md' 16 | - Entities: 'supported_entities.md' 17 | - Reference: 18 | - CLI: 'cli.md' 19 | - API: 'public_api.md' 20 | - Comment Directives: 'comment_directives.md' 21 | - Examples: 22 | - SQLAlchemy Interop: 'sqlalchemy_interop.md' 23 | - Row Level Security: 'row_level_security.md' 24 | - Advanced: 25 | - Authentication: 'authentication.md' 26 | - Performance: 'performance.md' 27 | 28 | markdown_extensions: 29 | - codehilite 30 | - admonition 31 | - pymdownx.inlinehilite 32 | - mkautodoc 33 | -------------------------------------------------------------------------------- /src/test/integration_array_test.py: -------------------------------------------------------------------------------- 1 | from nebulo.gql.relay.node_interface import NodeIdStructure 2 | 3 | SQL_UP = """ 4 | create table book( 5 | id serial primary key, 6 | flags bool[] 7 | ); 8 | 9 | insert into book(flags) values 10 | ('{true,false}'::bool[]); 11 | """ 12 | 13 | 14 | def test_query_multiple_fields(client_builder): 15 | client = client_builder(SQL_UP) 16 | book_id = 1 17 | node_id = NodeIdStructure(table_name="book", values={"id": book_id}).serialize() 18 | gql_query = f""" 19 | {{ 20 | book(nodeId: "{node_id}") {{ 21 | id 22 | flags 23 | }} 24 | }} 25 | """ 26 | with client: 27 | resp = client.post("/", json={"query": gql_query}) 28 | assert resp.status_code == 200 29 | 30 | result = resp.json() 31 | 32 | assert result["errors"] == [] 33 | assert result["data"]["book"]["id"] == book_id 34 | assert result["data"]["book"]["flags"] == [True, False] 35 | -------------------------------------------------------------------------------- /src/test/integration_enum_test.py: -------------------------------------------------------------------------------- 1 | from nebulo.gql.relay.node_interface import NodeIdStructure 2 | 3 | 
SQL_UP = """ 4 | CREATE TYPE light_color AS ENUM ('red', 'green', 'blue'); 5 | 6 | create table light( 7 | id serial primary key, 8 | color light_color 9 | ); 10 | 11 | insert into light(color) values 12 | ('green'); 13 | """ 14 | 15 | 16 | def test_query_multiple_fields(client_builder): 17 | client = client_builder(SQL_UP) 18 | light_id = 1 19 | node_id = NodeIdStructure(table_name="light", values={"id": light_id}).serialize() 20 | gql_query = f""" 21 | {{ 22 | light(nodeId: "{node_id}") {{ 23 | id 24 | color 25 | }} 26 | }} 27 | """ 28 | with client: 29 | resp = client.post("/", json={"query": gql_query}) 30 | assert resp.status_code == 200 31 | 32 | result = resp.json() 33 | 34 | assert result["errors"] == [] 35 | assert result["data"]["light"]["id"] == light_id 36 | assert result["data"]["light"]["color"] == "green" 37 | -------------------------------------------------------------------------------- /docs/authentication.md: -------------------------------------------------------------------------------- 1 | ## Authentication 2 | 3 | The approach to authentication and row level security were shamelessly borrowed from the [postgraphile](https://www.graphile.org/) project, a similar (but much more mature) PostgreSQL to GraphQL project written in Javascript. 4 | 5 | Please see [their security documentation](https://www.graphile.org/postgraphile/security/) for best practices when setting up secure auth for api user roles. 6 | 7 | For a complete usage, see [the row level security example project](https://github.com/olirice/nebulo/tree/master/examples/row_level_security) 8 | 9 | 10 | Note that JWT identifier and JWT secret can be passed to the nebulo CLI via `--jwt-identifier "public.my_jwt"` and `--jwt-secret "my_jwt_secret"` to cause functions returning the JWT type to reflect correctly. 11 | 12 | If the JWT type contains a `role` field, that role will be used to execute SQL statements for authenticated user's requests. 
Anonymous users' requests execute with the connections default role unless a default role is passed using the `--default-role` option 13 | 14 | -------------------------------------------------------------------------------- /src/nebulo/sql/reflection/names.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-argument 2 | from __future__ import annotations 3 | 4 | from nebulo.text_utils import camel_to_snake, snake_to_camel, to_plural 5 | from sqlalchemy import Table 6 | 7 | 8 | def rename_table(base, tablename: str, table: Table) -> str: 9 | """Produce a 'camelized' class name, e.g. 'words_and_underscores' -> 'WordsAndUnderscores'""" 10 | return snake_to_camel(tablename, upper=True) 11 | 12 | 13 | def rename_to_one_collection(base, local_cls, referred_cls, constraint) -> str: 14 | referred_name = camel_to_snake(referred_cls.__name__).lower() 15 | 16 | return referred_name + "_by_" + "_and".join(col.name for col in constraint.columns) 17 | 18 | 19 | def rename_to_many_collection(base, local_cls, referred_cls, constraint) -> str: 20 | "Produce an 'uncamelized', 'pluralized' class name, e.g. 
" 21 | "'SomeTerm' -> 'some_terms'" 22 | referred_name = camel_to_snake(referred_cls.__name__).lower() 23 | pluralized = to_plural(referred_name) 24 | return pluralized + "_by_" + "_and".join(col.name for col in constraint.columns) 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2020] [Oliver Rice] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/mirrors-isort 3 | rev: v5.7.0 4 | hooks: 5 | - id: isort 6 | args: ['--multi-line=3', '--trailing-comma', '--force-grid-wrap=0', '--use-parentheses', '--line-width=120'] 7 | 8 | - repo: https://github.com/pre-commit/pre-commit-hooks 9 | rev: v3.4.0 10 | hooks: 11 | - id: trailing-whitespace 12 | - id: check-yaml 13 | - id: check-toml 14 | - id: check-merge-conflict 15 | - id: check-added-large-files 16 | args: ['--maxkb=200'] 17 | - id: mixed-line-ending 18 | args: ['--fix=lf'] 19 | 20 | - repo: https://github.com/humitos/mirrors-autoflake 21 | rev: v1.1 22 | hooks: 23 | - id: autoflake 24 | args: ['--in-place', '--remove-all-unused-imports'] 25 | 26 | - repo: https://github.com/ambv/black 27 | rev: 20.8b1 28 | hooks: 29 | - id: black 30 | language_version: python3.8 31 | args: ['--line-length=120'] 32 | 33 | - repo: https://github.com/pre-commit/mirrors-mypy 34 | rev: v0.800 35 | hooks: 36 | - id: mypy 37 | additional_dependencies: ['sqlalchemy-stubs'] 38 | -------------------------------------------------------------------------------- /src/test/integration_composite_test.py: -------------------------------------------------------------------------------- 1 | from nebulo.gql.relay.node_interface import NodeIdStructure 2 | 3 | SQL_UP = """ 4 | CREATE TYPE full_name AS ( 5 | first_name text, 6 | last_name text 7 | ); 8 | 9 | create table account ( 10 | id serial primary key, 11 | name full_name not null 12 | ); 13 | 14 | insert into account(name) values 15 | (('oliver', 'rice')); 16 | """ 17 | 18 | 19 | def test_query_multiple_fields(client_builder): 20 | client = client_builder(SQL_UP) 21 | account_id = 1 22 | node_id = NodeIdStructure(table_name="account", values={"id": account_id}).serialize() 23 | gql_query = f""" 
24 | {{ 25 | account(nodeId: "{node_id}") {{ 26 | id 27 | name {{ 28 | first_name 29 | }} 30 | }} 31 | }} 32 | """ 33 | with client: 34 | resp = client.post("/", json={"query": gql_query}) 35 | assert resp.status_code == 200 36 | 37 | result = resp.json() 38 | 39 | assert result["errors"] == [] 40 | assert result["data"]["account"]["id"] == account_id 41 | assert result["data"]["account"]["name"]["first_name"] == "oliver" 42 | -------------------------------------------------------------------------------- /src/test/reflect_constraint_comment_test.py: -------------------------------------------------------------------------------- 1 | from nebulo.sql.inspect import get_comment, get_constraints, get_table_name 2 | from nebulo.sql.reflection.manager import reflect_sqla_models 3 | 4 | SQL_UP = """ 5 | CREATE TABLE public.person ( 6 | id serial primary key 7 | ); 8 | 9 | CREATE TABLE public.address ( 10 | id serial primary key, 11 | person_id int, 12 | 13 | constraint fk_person foreign key (person_id) references public.person (id) 14 | ); 15 | 16 | comment on constraint fk_person on public.address is '@name Person Addresses'; 17 | """ 18 | 19 | 20 | def test_reflect_fkey_comment(engine): 21 | engine.execute(SQL_UP) 22 | tables, _ = reflect_sqla_models(engine, schema="public") 23 | 24 | table = [x for x in tables if get_table_name(x) == "address"][0] 25 | 26 | constraint = [x for x in get_constraints(table) if x.name == "fk_person"][0] 27 | assert get_comment(constraint) == "@name Person Addresses" 28 | 29 | 30 | def test_reflect_fkey_comment_to_schema(schema_builder): 31 | schema = schema_builder(SQL_UP) 32 | assert "Addresses" in schema.type_map["Person"].fields 33 | assert "Person" in schema.type_map["Address"].fields 34 | -------------------------------------------------------------------------------- /src/test/create_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | SQL_UP = """ 4 | CREATE TABLE 
account ( 5 | id serial primary key, 6 | name text not null, 7 | created_at timestamp without time zone default (now() at time zone 'utc') 8 | ); 9 | 10 | INSERT INTO account (id, name) VALUES 11 | (10, 'oliver'), 12 | (20, 'rachel'), 13 | (30, 'sophie'); 14 | """ 15 | 16 | 17 | def test_create_mutation(client_builder): 18 | client = client_builder(SQL_UP) 19 | query = """ 20 | 21 | mutation { 22 | createAccount(input: { 23 | clientMutationId: "a44df", 24 | account: { 25 | name: "Buddy" 26 | } 27 | }) { 28 | cid: clientMutationId 29 | account { 30 | dd: id 31 | nodeId 32 | name 33 | } 34 | } 35 | } 36 | """ 37 | 38 | with client: 39 | resp = client.post("/", json={"query": query}) 40 | assert resp.status_code == 200 41 | payload = json.loads(resp.text) 42 | print(payload) 43 | assert isinstance(payload["data"]["createAccount"]["account"], dict) 44 | assert payload["data"]["createAccount"]["account"]["dd"] 45 | assert payload["data"]["createAccount"]["account"]["name"] == "Buddy" 46 | assert len(payload["errors"]) == 0 47 | -------------------------------------------------------------------------------- /src/test/delete_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from nebulo.gql.relay.node_interface import NodeIdStructure 4 | 5 | SQL_UP = """ 6 | CREATE TABLE account ( 7 | id serial primary key, 8 | name text not null, 9 | created_at timestamp without time zone default (now() at time zone 'utc') 10 | ); 11 | 12 | INSERT INTO account (id, name) VALUES 13 | (1, 'oliver'), 14 | (2, 'rachel'), 15 | (3, 'sophie'); 16 | """ 17 | 18 | 19 | def test_delete_mutation(client_builder): 20 | client = client_builder(SQL_UP) 21 | account_id = 1 22 | node_id = NodeIdStructure(table_name="account", values={"id": account_id}).serialize() 23 | query = f""" 24 | 25 | mutation {{ 26 | deleteAccount(input: {{ 27 | clientMutationId: "gjwl" 28 | nodeId: "{node_id}" 29 | }}) {{ 30 | cid: clientMutationId 31 | nodeId 32 | 
}} 33 | }} 34 | """ 35 | 36 | with client: 37 | resp = client.post("/", json={"query": query}) 38 | assert resp.status_code == 200 39 | payload = json.loads(resp.text) 40 | print(payload) 41 | assert isinstance(payload["data"]["deleteAccount"], dict) 42 | assert payload["data"]["deleteAccount"]["cid"] == "gjwl" 43 | assert payload["data"]["deleteAccount"]["nodeId"] == node_id 44 | assert len(payload["errors"]) == 0 45 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | 9 | services: 10 | 11 | postgres: 12 | image: postgres:12 13 | env: 14 | POSTGRES_HOST: localhost 15 | POSTGRES_USER: nebulo_user 16 | POSTGRES_PASSWORD: password 17 | POSTGRES_DB: nebulo_db 18 | 19 | ports: 20 | - 4442:5432 21 | options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 22 | 23 | steps: 24 | - uses: actions/checkout@v1 25 | - name: Set up Python 3.7 26 | uses: actions/setup-python@v1 27 | with: 28 | python-version: 3.7 29 | - name: Install nebulo 30 | run: | 31 | sudo apt-get install libpq-dev 32 | python -m pip install --upgrade pip 33 | pip install ".[test]" 34 | 35 | - name: Run Tests with Coverage 36 | run: | 37 | pytest --cov=nebulo src/test --cov-report=xml 38 | 39 | - name: Upload coverage to Codecov 40 | uses: codecov/codecov-action@v1 41 | with: 42 | token: ${{ secrets.CODECOV_TOKEN }} 43 | file: ./coverage.xml 44 | flags: unittests 45 | name: codecov-umbrella 46 | fail_ci_if_error: true 47 | -------------------------------------------------------------------------------- /src/nebulo/lexer.py: -------------------------------------------------------------------------------- 1 | from pygments.lexer import RegexLexer 2 | from pygments.token import Comment, Keyword, Name, Number, Operator, Punctuation, String, Text 
__all__ = ["GraphQLLexer"]


class GraphQLLexer(RegexLexer):
    """
    Pygments GraphQL lexer for mkdocs
    """

    name = "GraphQL"
    aliases = ["graphql", "gql"]
    filenames = ["*.graphql", "*.gql"]
    mimetypes = ["application/graphql"]

    tokens = {
        "root": [
            # Use the standard Pygments token names: `Comment.Singline` and
            # `Comment.Multi` are not recognized by any style, so comments were
            # rendered unhighlighted.
            (r"#.*", Comment.Single),
            (r'"""("?"?(\\"""|\\(?!=""")|[^"\\]))*"""', Comment.Multiline),
            (r"\.\.\.", Operator),
            # String literals (double-quoted)
            (r'".*"', String.Double),
            (r"(-?0|-?[1-9][0-9]*)(\.[0-9]+[eE][+-]?[0-9]+|\.[0-9]+|[eE][+-]?[0-9]+)", Number.Float),
            (r"(-?0|-?[1-9][0-9]*)", Number.Integer),
            (r"\$+[_A-Za-z][_0-9A-Za-z]*", Name.Variable),
            # Field aliases / argument names followed by a colon
            (r"[_A-Za-z][_0-9A-Za-z]+\s?:", Text),
            (
                r"(type|query|interface|mutation|extend|input|implements|directive|@[a-z]+|on|true|false|null)\b",
                Keyword.Type,
            ),
            (r"[!$():=@\[\]{|}]+?", Punctuation),
            (r"[_A-Za-z][_0-9A-Za-z]*", Keyword),
            (r"(\s|,)", Text),
        ]
    }
from __future__ import annotations

from collections import namedtuple
from typing import List, NamedTuple, Type

from sqlalchemy import Column
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.types import UserDefinedType


class CompositeType(UserDefinedType):  # type: ignore
    """
    Represents a PostgreSQL composite type.

    :param name:
        Name of the composite type.
    :param columns:
        List of columns that this composite type consists of
    """

    python_type = tuple

    # Class name used for the composite type on the SQLAlchemy side
    name: str
    # Columns making up the composite; populated by composite_type_factory
    columns: List[Column] = []
    # namedtuple class used to represent values of this composite type
    type_cls: Type[NamedTuple]

    # Name and schema of the composite type within PostgreSQL
    pg_name: str
    pg_schema: str

    # NOTE(review): named `init`, not `__init__`, so it is never invoked on
    # instantiation -- looks like dead code or a typo; confirm intent before
    # renaming, as SQLAlchemy instantiates TypeEngine subclasses internally.
    def init(self, *args, **kwargs):
        pass


def composite_type_factory(name: str, columns: List[Column], pg_name: str, pg_schema: str) -> TypeEngine:
    """Dynamically build a CompositeType subclass mapping the PostgreSQL
    composite *pg_schema*.*pg_name* with the given *columns*."""

    for column in columns:
        # Make each column addressable by its SQL name
        column.key = column.name

    # Row values of this composite are deserialized into this namedtuple
    type_cls: Type[NamedTuple] = namedtuple(name, [c.name for c in columns])  # type: ignore

    composite = type(
        name,
        (CompositeType,),
        {
            "name": name,
            "columns": columns,
            "type_cls": type_cls,
            "pg_name": pg_name,
            "pg_schema": pg_schema,
        },  # type:ignore
    )
    return composite  # type: ignore
| def get_environ(cls): 43 | try: 44 | with cls.app_env.open("r") as f: 45 | for row in f: 46 | key, value = row.split("=", 1) 47 | os.environ[key.strip()] = value.strip() 48 | except FileNotFoundError: 49 | pass 50 | return os.environ 51 | -------------------------------------------------------------------------------- /examples/row_level_security/README.md: -------------------------------------------------------------------------------- 1 | # Nebulo Row Level Security Example: 2 | 3 | This example leverages PostgreSQL [roles](https://www.postgresql.org/docs/8.1/user-manag.html) and [row level security](https://www.postgresql.org/docs/current/ddl-rowsecurity.html) with nebulo's builtin [JWT](https://jwt.io/) support to expose a GraphQL API with different permissions for anonymous and authenticated users. 4 | 5 | Permissions are scoped such that: 6 | 7 | - Anonymous users may: 8 | - Create an account 9 | - See blog posts 10 | 11 | - Authenticated users may: 12 | - Edit their own account info 13 | - Create a blog post 14 | - Edit their own blog posts 15 | - See blog posts 16 | 17 | Start the database 18 | ```shell 19 | docker run --rm --name nebulo_rls -p 5523:5432 -d -e POSTGRES_DB=nebulo_example -e POSTGRES_PASSWORD=app_password -e POSTGRES_USER=app_user -d postgres 20 | ``` 21 | 22 | Copy the SQL schema setup script into the container 23 | ```shell 24 | docker cp ./setup.sql nebulo_rls:/docker-entrypoint-initdb.d/setup.sql 25 | ``` 26 | 27 | Run the SQL setup script on the database 28 | ```shell 29 | docker exec -u postgres nebulo_rls psql nebulo_example app_user -f docker-entrypoint-initdb.d/setup.sql 30 | ``` 31 | 32 | Then start the webserver 33 | ```shell 34 | neb run -c postgresql://nebulo_user:password@localhost:4443/nebulo_db --default-role anon_api_user --jwt-identifier public.jwt --jwt-secret super_duper_secret 35 | ``` 36 | 37 | 38 | Explore the GraphQL API at [http://0.0.0.0:5034/graphiql](http://0.0.0.0:5034/graphiql) 39 | 
-------------------------------------------------------------------------------- /src/test/update_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from nebulo.gql.relay.node_interface import NodeIdStructure 4 | 5 | SQL_UP = """ 6 | CREATE TABLE account ( 7 | id serial primary key, 8 | name text not null, 9 | created_at timestamp without time zone default (now() at time zone 'utc') 10 | ); 11 | 12 | INSERT INTO account (id, name) VALUES 13 | (1, 'oliver'), 14 | (2, 'rachel'), 15 | (3, 'sophie'); 16 | """ 17 | 18 | 19 | def test_update_mutation(client_builder): 20 | client = client_builder(SQL_UP) 21 | account_id = 1 22 | node_id = NodeIdStructure(table_name="account", values={"id": account_id}).serialize() 23 | query = f""" 24 | 25 | mutation {{ 26 | updateAccount(input: {{ 27 | clientMutationId: "gjwl" 28 | nodeId: "{node_id}" 29 | account: {{ 30 | name: "Buddy" 31 | }} 32 | }}) {{ 33 | clientMutationId 34 | account {{ 35 | dd: id 36 | nodeId 37 | name 38 | }} 39 | }} 40 | }} 41 | """ 42 | 43 | with client: 44 | resp = client.post("/", json={"query": query}) 45 | assert resp.status_code == 200 46 | payload = json.loads(resp.text) 47 | assert isinstance(payload["data"]["updateAccount"]["account"], dict) 48 | assert payload["data"]["updateAccount"]["account"]["dd"] == account_id 49 | assert payload["data"]["updateAccount"]["account"]["nodeId"] == node_id 50 | assert payload["data"]["updateAccount"]["account"]["name"] == "Buddy" 51 | assert payload["data"]["updateAccount"]["clientMutationId"] == "gjwl" 52 | assert len(payload["errors"]) == 0 53 | -------------------------------------------------------------------------------- /src/nebulo/server/jwt.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Awaitable, Callable, Dict, Optional 2 | 3 | import jwt 4 | from jwt.exceptions import DecodeError, ExpiredSignatureError, PyJWTError 5 | from 
def get_jwt_claims_handler(secret: Optional[str]) -> Callable[[Request], Awaitable[Dict[str, Any]]]:
    """Return a coroutine function that retrieves and decodes JWT claims from the Starlette Request"""

    async def get_jwt_claims(request: Request) -> Dict[str, Any]:
        """Decode the 'Authorization: Bearer <token>' header into a claims dict.

        Returns an empty dict when JWT auth is disabled (no secret) or when no
        Authorization header is present. Raises HTTPException(401) for
        malformed, invalid, or expired tokens.
        """

        if secret is None:
            return {}

        if "Authorization" not in request.headers:
            return {}

        auth = request.headers["Authorization"]

        # A malformed header (e.g. missing token part) previously raised an
        # uncaught ValueError from tuple unpacking -> HTTP 500. Validate the
        # shape explicitly and reject with 401 instead.
        parts = auth.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            raise HTTPException(401, "Invalid JWT Authorization header. Expected 'Bearer <token>'")
        token = parts[1]

        try:
            return jwt.decode(token, secret, algorithms=["HS256"])
        except DecodeError:
            raise HTTPException(401, "Invalid JWT credentials")
        except ExpiredSignatureError:
            raise HTTPException(401, "JWT has expired. Please reauthenticate")
        # Generically catch all remaining PyJWT errors
        except PyJWTError as exc:
            raise HTTPException(401, str(exc))

    return get_jwt_claims
import typing

from sqlalchemy import Text, func, literal, select
from sqlalchemy.sql.selectable import Select


def build_claims(jwt_claims: typing.Dict[str, typing.Any], default_role: typing.Optional[str] = None) -> Select:
    """Emit statement to set 'jwt.claims.<key>' for each claim in claims dict
    and a 'role'

    The caller's *jwt_claims* dict is not modified; the default role is
    applied to an internal copy only.
    """
    # Setting local variables cannot be done in a prepared statement.
    # Since JWT claims are signed, literal binds should be ok.

    # Keys with special meaning that should not be prefixed with 'jwt.claims'
    role_key = "role"
    special_keys = {role_key, "statement_timeout"}

    # Work on a copy so the input dict is not mutated as a side effect
    effective_claims = dict(jwt_claims)

    # Set role to default if exists and not provided by jwt
    if role_key not in effective_claims and default_role is not None:
        effective_claims[role_key] = default_role

    # Set all jwt.claims.*
    claims = [
        func.set_config(
            literal("jwt.claims.").op("||")(func.cast(claim_key, Text())),
            func.cast(str(claim_value), Text()),
            True,
        )
        for claim_key, claim_value in effective_claims.items()
        if claim_key not in special_keys
    ]

    # Set keys with special meaning to postgres
    for key in special_keys:
        if key in effective_claims:
            claims.append(
                func.set_config(
                    func.cast(key, Text()),
                    func.cast(str(effective_claims[key]), Text()),
                    True,
                )
            )

    return select(claims)
@primary_key (account_id)'; 18 | 19 | 20 | insert into account(id, name) values 21 | (1, 'Oliver'); 22 | """ 23 | 24 | 25 | def test_reflect_view(schema_builder): 26 | schema = schema_builder(SQL_UP) 27 | 28 | query_type = schema.query_type 29 | assert "accountView" in query_type.fields 30 | 31 | mutation_type = schema.mutation_type 32 | assert "createAccountView" not in mutation_type.fields 33 | assert "updateAccountView" not in mutation_type.fields 34 | assert "deleteAccountView" not in mutation_type.fields 35 | 36 | 37 | def test_query_view(client_builder): 38 | client = client_builder(SQL_UP) 39 | 40 | gql_query = """ 41 | query { 42 | allAccountViews { 43 | edges { 44 | node { 45 | nodeId 46 | accountId 47 | accountName 48 | accountByAccountIdToId{ 49 | name 50 | } 51 | } 52 | } 53 | } 54 | } 55 | """ 56 | 57 | with client: 58 | resp = client.post("/", json={"query": gql_query}) 59 | result = resp.json() 60 | assert resp.status_code == 200 61 | 62 | print(result) 63 | assert result["errors"] == [] 64 | assert result["data"]["allAccountViews"]["edges"][0]["node"]["accountId"] == 1 65 | assert result["data"]["allAccountViews"]["edges"][0]["node"]["accountByAccountIdToId"]["name"] == "Oliver" 66 | -------------------------------------------------------------------------------- /docs/cli.md: -------------------------------------------------------------------------------- 1 | ## Command Line Interface 2 | 3 | The CLI provides a path for deploying GraphQL APIs without writting any python code. For full usage see [the row level security example](row_level_security.md) 4 | 5 | ```text 6 | Usage: neb [OPTIONS] COMMAND [ARGS]... 7 | 8 | Options: 9 | --version Show the version and exit. 10 | --help Show this message and exit. 
11 | 12 | Commands: 13 | dump-schema Dump the GraphQL Schema to stdout or file 14 | run Run the GraphQL Web Server 15 | ``` 16 | 17 | 18 | #### neb run 19 | 20 | Starts the GraphQL web server and serves the interactive documentation at `http://:/graphiql` 21 | 22 | ```text 23 | Usage: neb run [OPTIONS] 24 | 25 | Run the GraphQL Web Server 26 | 27 | Options: 28 | -c, --connection TEXT Database connection string 29 | -p, --port INTEGER Web server port 30 | -h, --host TEXT Host address 31 | -w, --workers INTEGER Number of parallel workers 32 | -s, --schema TEXT SQL schema name 33 | --jwt-identifier TEXT JWT composite type identifier e.g. "public.jwt" 34 | --jwt-secret TEXT Secret key for JWT encryption 35 | --reload / --no-reload Reload if source files change 36 | --default-role TEXT Default PostgreSQL role for anonymous users 37 | --help Show this message and exit. 38 | ``` 39 | 40 | 41 | 42 | #### neb dump-schema 43 | 44 | Export the GraphQL schema 45 | 46 | ```text 47 | Usage: neb dump-schema [OPTIONS] 48 | 49 | Dump the GraphQL Schema to stdout or file 50 | 51 | Options: 52 | -c, --connection TEXT Database connection string 53 | -s, --schema TEXT SQL schema name 54 | -o, --out-file FILENAME Output file path 55 | --help Show this message and exit. 
56 | ``` 57 | -------------------------------------------------------------------------------- /src/test/function_returns_row_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from nebulo.sql.reflection.function import reflect_functions 4 | from sqlalchemy.dialects.postgresql import base as pg_base 5 | 6 | CREATE_FUNCTION = """ 7 | create table account( 8 | id int primary key, 9 | name text 10 | ); 11 | 12 | insert into account (id, name) 13 | values (1, 'oli'); 14 | 15 | create function get_account(id int) 16 | returns account 17 | as $$ 18 | select (1, 'oli')::account; 19 | $$ language sql; 20 | """ 21 | 22 | 23 | def test_reflect_function_returning_row(engine, session): 24 | session.execute(CREATE_FUNCTION) 25 | session.commit() 26 | 27 | functions = reflect_functions(engine, schema="public", type_map=pg_base.ischema_names) 28 | 29 | get_account = functions[0] 30 | res = session.execute(get_account.to_executable([1])).first() 31 | print(res) 32 | # psycopg2 does not know how to deserialize row results 33 | assert res == ("(1,oli)",) 34 | 35 | 36 | def test_integration_function(client_builder): 37 | client = client_builder(CREATE_FUNCTION) 38 | 39 | query = """ 40 | mutation { 41 | getAccount(input: {id: 1, clientMutationId: "abcdef"}) { 42 | cmi: clientMutationId 43 | out: result { 44 | nodeId 45 | id 46 | } 47 | } 48 | } 49 | """ 50 | 51 | with client: 52 | resp = client.post("/", json={"query": query}) 53 | result = json.loads(resp.text) 54 | print(result) 55 | assert resp.status_code == 200 56 | assert result["errors"] == [] 57 | assert result["data"]["getAccount"]["out"]["id"] == 1 58 | assert result["data"]["getAccount"]["out"]["nodeId"] is not None 59 | assert result["data"]["getAccount"]["cmi"] == "abcdef" 60 | -------------------------------------------------------------------------------- /src/test/integration_one_basic_test.py: 
-------------------------------------------------------------------------------- 1 | from nebulo.gql.relay.node_interface import NodeIdStructure 2 | 3 | SQL_UP = """ 4 | CREATE TABLE account ( 5 | id serial primary key, 6 | name text not null, 7 | created_at timestamp without time zone default (now() at time zone 'utc') 8 | ); 9 | 10 | INSERT INTO account (id, name) VALUES 11 | (1, 'oliver'), 12 | (2, 'rachel'), 13 | (3, 'sophie'), 14 | (4, 'buddy'); 15 | """ 16 | 17 | 18 | def test_query_one_field(client_builder): 19 | client = client_builder(SQL_UP) 20 | account_id = 1 21 | node_id = NodeIdStructure(table_name="account", values={"id": account_id}).serialize() 22 | gql_query = f""" 23 | {{ 24 | account(nodeId: "{node_id}") {{ 25 | id 26 | }} 27 | }} 28 | """ 29 | with client: 30 | resp = client.post("/", json={"query": gql_query}) 31 | assert resp.status_code == 200 32 | 33 | result = resp.json() 34 | assert isinstance(result["data"]["account"], dict) 35 | assert result["data"]["account"]["id"] == account_id 36 | 37 | 38 | def test_query_multiple_fields(client_builder): 39 | client = client_builder(SQL_UP) 40 | account_id = 1 41 | node_id = NodeIdStructure(table_name="account", values={"id": account_id}).serialize() 42 | gql_query = f""" 43 | {{ 44 | account(nodeId: "{node_id}") {{ 45 | id 46 | name 47 | createdAt 48 | }} 49 | }} 50 | """ 51 | with client: 52 | resp = client.post("/", json={"query": gql_query}) 53 | assert resp.status_code == 200 54 | 55 | result = resp.json() 56 | 57 | assert result["errors"] == [] 58 | assert result["data"]["account"]["id"] == account_id 59 | assert result["data"]["account"]["name"] == "oliver" 60 | assert isinstance(result["data"]["account"]["createdAt"], str) 61 | -------------------------------------------------------------------------------- /src/test/integration_many_basic_test.py: -------------------------------------------------------------------------------- 1 | SQL_UP = """ 2 | CREATE TABLE account ( 3 | id serial 
primary key, 4 | name text not null, 5 | created_at timestamp without time zone default (now() at time zone 'utc') 6 | ); 7 | 8 | INSERT INTO account (id, name) VALUES 9 | (1, 'oliver'), 10 | (2, 'rachel'), 11 | (3, 'sophie'), 12 | (4, 'buddy'); 13 | """ 14 | 15 | 16 | def test_query_multiple_fields(client_builder): 17 | client = client_builder(SQL_UP) 18 | gql_query = f""" 19 | {{ 20 | allAccounts {{ 21 | edges {{ 22 | node {{ 23 | id 24 | name 25 | createdAt 26 | }} 27 | }} 28 | }} 29 | }} 30 | """ 31 | with client: 32 | resp = client.post("/", json={"query": gql_query}) 33 | assert resp.status_code == 200 34 | 35 | result = resp.json() 36 | assert result["errors"] == [] 37 | assert "edges" in result["data"]["allAccounts"] 38 | assert "node" in result["data"]["allAccounts"]["edges"][0] 39 | 40 | 41 | def test_arg_first(client_builder): 42 | client = client_builder(SQL_UP) 43 | gql_query = f""" 44 | {{ 45 | allAccounts(first: 2) {{ 46 | edges {{ 47 | node {{ 48 | id 49 | name 50 | createdAt 51 | }} 52 | }} 53 | }} 54 | }} 55 | """ 56 | 57 | with client: 58 | resp = client.post("/", json={"query": gql_query}) 59 | assert resp.status_code == 200 60 | 61 | result = resp.json() 62 | assert result["errors"] == [] 63 | 64 | assert len(result["data"]["allAccounts"]["edges"]) == 2 65 | assert result["data"]["allAccounts"]["edges"][0]["node"]["id"] == 1 66 | assert result["data"]["allAccounts"]["edges"][1]["node"]["id"] == 2 67 | -------------------------------------------------------------------------------- /src/test/condition_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | SQL_UP = """ 4 | CREATE TABLE account ( 5 | id serial primary key, 6 | name text not null, 7 | age int not null, 8 | created_at timestamp without time zone default (now() at time zone 'utc') 9 | ); 10 | 11 | INSERT INTO account (id, name, age) VALUES 12 | (1, 'oliver', 28), 13 | (2, 'rachel', 28), 14 | (3, 'sophie', 1), 15 | (4, 
'buddy', 20); 16 | """ 17 | 18 | 19 | def test_query_with_int_condition(client_builder): 20 | client = client_builder(SQL_UP) 21 | gql_query = f""" 22 | {{ 23 | allAccounts(condition: {{id: 1}}) {{ 24 | edges {{ 25 | node {{ 26 | id 27 | }} 28 | }} 29 | }} 30 | }} 31 | """ 32 | 33 | with client: 34 | resp = client.post("/", json={"query": gql_query}) 35 | 36 | assert resp.status_code == 200 37 | print(resp.text) 38 | 39 | payload = json.loads(resp.text) 40 | assert payload["errors"] == [] 41 | 42 | result_id = payload["data"]["allAccounts"]["edges"][0]["node"]["id"] 43 | assert result_id == 1 44 | 45 | 46 | def test_query_with_string_condition(client_builder): 47 | client = client_builder(SQL_UP) 48 | 49 | gql_query = f""" 50 | {{ 51 | allAccounts(condition: {{name: "sophie"}}) {{ 52 | edges {{ 53 | node {{ 54 | id 55 | name 56 | }} 57 | }} 58 | }} 59 | }} 60 | """ 61 | 62 | with client: 63 | resp = client.post("/", json={"query": gql_query}) 64 | 65 | assert resp.status_code == 200 66 | print(resp.text) 67 | 68 | payload = json.loads(resp.text) 69 | assert payload["errors"] == [] 70 | 71 | result_name = payload["data"]["allAccounts"]["edges"][0]["node"]["name"] 72 | assert result_name == "sophie" 73 | -------------------------------------------------------------------------------- /src/test/function_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from nebulo.sql.reflection.function import reflect_functions 4 | from sqlalchemy import select 5 | from sqlalchemy.dialects.postgresql import base as pg_base 6 | 7 | CREATE_FUNCTION = """ 8 | create function public.to_upper(some_text text) 9 | returns text 10 | as $$ 11 | select upper(some_text); 12 | $$ language sql; 13 | 14 | 15 | -- solve query root type nust not be none issue 16 | create table trash ( 17 | id serial primary key, 18 | name text 19 | ); 20 | """ 21 | 22 | 23 | def test_reflect_function(engine, session): 24 | 
session.execute(CREATE_FUNCTION) 25 | session.commit() 26 | 27 | functions = reflect_functions(engine, schema="public", type_map=pg_base.ischema_names) 28 | 29 | to_upper = [x for x in functions if x.name == "to_upper"] 30 | assert len(to_upper) == 1 31 | 32 | 33 | def test_call_function(engine, session): 34 | session.execute(CREATE_FUNCTION) 35 | session.commit() 36 | 37 | functions = reflect_functions(engine, schema="public", type_map=pg_base.ischema_names) 38 | to_upper = [x for x in functions if x.name == "to_upper"][0] 39 | 40 | query = select([to_upper.to_executable(["abc"]).label("result")]) 41 | result = session.execute(query).fetchone()["result"] 42 | assert result == "ABC" 43 | 44 | 45 | def test_integration_function(client_builder): 46 | client = client_builder(CREATE_FUNCTION) 47 | 48 | query = """ 49 | mutation { 50 | toUpper(input: {some_text: "abc", clientMutationId: "some_client_id"}) { 51 | result 52 | clientMutationId 53 | } 54 | } 55 | """ 56 | 57 | with client: 58 | resp = client.post("/", json={"query": query}) 59 | print(resp.text) 60 | result = json.loads(resp.text) 61 | assert resp.status_code == 200 62 | assert result["errors"] == [] 63 | assert result["data"]["toUpper"]["result"] == "ABC" 64 | -------------------------------------------------------------------------------- /src/test/reflection_test.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy as sqla 2 | from nebulo.gql import alias 3 | from nebulo.gql.convert.column import convert_type 4 | from sqlalchemy import MetaData 5 | from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION 6 | from sqlalchemy.ext.automap import automap_base 7 | 8 | SQL_UP = """ 9 | CREATE TABLE public.refl_numer ( 10 | col_0 int primary key, 11 | col_1 smallint, 12 | col_2 integer, 13 | col_3 bigint, 14 | col_4 smallserial, 15 | col_5 bigserial, 16 | col_6 numeric, 17 | col_7 numeric(10,2), 18 | col_8 real, 19 | col_9 double precision, 20 | 
col_10 boolean, 21 | col_11 boolean[] 22 | ); 23 | """ 24 | 25 | 26 | def test_reflect_types(engine, session): 27 | session.execute(SQL_UP) 28 | session.commit() 29 | 30 | TableBase = automap_base(metadata=MetaData()) 31 | 32 | TableBase.prepare(engine, reflect=True, schema="public") 33 | tab = TableBase.classes["refl_numer"] 34 | tab.col_7.type.asdecimal = True 35 | assert isinstance(tab.col_0.type, sqla.Integer) 36 | assert isinstance(tab.col_1.type, sqla.Integer) 37 | assert isinstance(tab.col_2.type, sqla.Integer) 38 | assert isinstance(tab.col_3.type, sqla.Integer) 39 | assert isinstance(tab.col_4.type, sqla.Integer) 40 | assert isinstance(tab.col_5.type, sqla.Integer) 41 | assert isinstance(tab.col_6.type, sqla.Numeric) 42 | assert isinstance(tab.col_7.type, sqla.Numeric) 43 | assert isinstance(tab.col_8.type, sqla.REAL) 44 | assert isinstance(tab.col_9.type, DOUBLE_PRECISION) 45 | assert isinstance(tab.col_10.type, sqla.Boolean) 46 | assert isinstance(tab.col_11.type, ARRAY) 47 | 48 | 49 | def test_reflect_gql_boolean(engine, session): 50 | session.execute(SQL_UP) 51 | session.commit() 52 | 53 | TableBase = automap_base(metadata=MetaData()) 54 | 55 | TableBase.prepare(engine, reflect=True, schema="public") 56 | tab = TableBase.classes["refl_numer"] 57 | 58 | gql_bool = convert_type(tab.col_10.type) 59 | assert gql_bool == alias.Boolean 60 | -------------------------------------------------------------------------------- /docs/comment_directives.md: -------------------------------------------------------------------------------- 1 | # Comment Directives (Experimental) 2 | 3 | Some SQL comments are interpreted as directives and impact how nebulo reflects database entities. 4 | 5 | ## Exclude 6 | 7 | *Usage* 8 | 9 | `@exclude ...` 10 | 11 | 12 | *Description* 13 | 14 | The exclude directive can be applied to tables, views, foreign keys, table columns, or view columns. 
The mutations or queries from which the entity should be omitted are described with a combination of the following parameters.

**Table/View Allowed Params:**

- create
- read
- read_one
- read_all
- update
- delete

**Column Allowed Params:**

- create
- read
- update
- delete

**Foreign Key Allowed Params:**
- read


For example, the directive `@exclude delete, update` would make the entity immutable.

Note:
- `read` is equivalent to `read_one` and `read_all` together.
- `@exclude connection` prevents any associations to the table via foreign keys.

#### Example

The following examples show how you could exclude a column named `password_hash` from being exposed through the GraphQL API using SQL or SQLAlchemy.


**SQL**
```sql
create table account (
    username text not null primary key,
    password_hash text not null
);

comment on column account.password_hash is E'@exclude create, read, update, delete';
```


**SQLAlchemy**
```python
class Account(Base):
    __tablename__ = "account"

    username = Column(Text, primary_key=True)
    password_hash = Column(Text, comment="@exclude create, read, update, delete")
```

## Name

*Usage*

`@name <new_name>`


*Description*


The name directive can be applied to tables and table columns to override the name of the GraphQL types associated with the commented entity.

For example, the directive `@name ShippingAddress` on an `address` table would result in the GraphQL type being named `ShippingAddress`.
80 | -------------------------------------------------------------------------------- /src/nebulo/gql/relay/cursor.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=comparison-with-callable 2 | from __future__ import annotations 3 | 4 | import json 5 | import typing 6 | 7 | from nebulo.gql.alias import ScalarType 8 | from nebulo.sql.inspect import get_primary_key_columns, get_table_name 9 | from nebulo.sql.statement_helpers import literal_string 10 | from nebulo.text_utils.base64 import from_base64, to_base64 11 | from sqlalchemy import func 12 | from sqlalchemy.sql.selectable import Alias 13 | 14 | 15 | class CursorStructure(typing.NamedTuple): 16 | table_name: str 17 | values: typing.Dict[str, typing.Any] 18 | 19 | @classmethod 20 | def from_dict(cls, contents: typing.Dict) -> CursorStructure: 21 | res = cls(table_name=contents["table_name"], values=contents["values"]) 22 | return res 23 | 24 | def to_dict(self) -> typing.Dict[str, typing.Any]: 25 | return {"table_name": self.table_name, "values": self.values} 26 | 27 | def serialize(self) -> str: 28 | ser = to_base64(json.dumps(self.to_dict())) 29 | return ser 30 | 31 | @classmethod 32 | def deserialize(cls, serialized: str) -> CursorStructure: 33 | contents = json.loads(from_base64(serialized)) 34 | return cls.from_dict(contents) 35 | 36 | 37 | def serialize(value: typing.Union[CursorStructure, typing.Dict]): 38 | node_id = CursorStructure.from_dict(value) if isinstance(value, dict) else value 39 | return node_id.serialize() 40 | 41 | 42 | def to_cursor_sql(sqla_model, query_elem: Alias): 43 | table_name = get_table_name(sqla_model) 44 | 45 | pkey_cols = get_primary_key_columns(sqla_model) 46 | 47 | # Columns selected from query element 48 | vals = [] 49 | for col in pkey_cols: 50 | col_name = str(col.name) 51 | vals.extend([literal_string(col_name), query_elem.c[col_name]]) 52 | 53 | return func.jsonb_build_object( 54 | literal_string("table_name"), 55 | 
literal_string(table_name), 56 | literal_string("values"), 57 | func.jsonb_build_object(*vals), 58 | ) 59 | 60 | 61 | Cursor = ScalarType( 62 | "Cursor", 63 | description="Pagination point", 64 | serialize=serialize, 65 | parse_value=CursorStructure.deserialize, 66 | parse_literal=lambda x: CursorStructure.deserialize(x.value), 67 | ) 68 | -------------------------------------------------------------------------------- /src/test/app_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import typing 5 | 6 | if typing.TYPE_CHECKING: 7 | from starlette.applications import Starlette 8 | 9 | SQL_UP = """ 10 | CREATE TABLE account ( 11 | id serial primary key, 12 | name text not null, 13 | created_at timestamp without time zone default (now() at time zone 'utc') 14 | ); 15 | 16 | INSERT INTO account (id, name) VALUES 17 | (1, 'oliver'), 18 | (2, 'rachel'), 19 | (3, 'sophie'), 20 | (4, 'buddy'); 21 | 22 | 23 | CREATE FUNCTION to_lower(some_text text) returns text as 24 | $$ select lower(some_text) $$ language sql immutable; 25 | """ 26 | 27 | 28 | def test_app_has_route(app_builder): 29 | app: Starlette = app_builder(SQL_UP) 30 | routes: typing.List[str] = [x.path for x in app.routes] 31 | assert "/" in routes 32 | 33 | 34 | def test_app_serves_graphiql(client_builder): 35 | client = client_builder(SQL_UP) 36 | headers = {"Accept": "text/html"} 37 | with client: 38 | resp = client.get("/graphiql", headers=headers) 39 | print(resp) 40 | assert resp.status_code == 200 41 | 42 | 43 | def test_app_serves_graphql_query_from_application_json(client_builder): 44 | client = client_builder(SQL_UP) 45 | 46 | query = f""" 47 | {{ 48 | allAccounts {{ 49 | edges {{ 50 | node {{ 51 | id 52 | name 53 | createdAt 54 | }} 55 | }} 56 | }} 57 | }} 58 | """ 59 | 60 | with client: 61 | resp = client.post("/", json={"query": query}) 62 | assert resp.status_code == 200 63 | print(resp.text) 64 | 
65 | payload = json.loads(resp.text) 66 | assert "data" in payload 67 | assert len(payload["data"]["allAccounts"]["edges"]) == 4 68 | 69 | 70 | def test_app_serves_mutation_function(client_builder): 71 | client = client_builder(SQL_UP) 72 | 73 | query = """ 74 | query { 75 | toLower(some_text: "AbC") 76 | } 77 | """ 78 | 79 | with client: 80 | resp = client.post("/", json={"query": query}) 81 | assert resp.status_code == 200 82 | print(resp.text) 83 | 84 | payload = json.loads(resp.text) 85 | assert "data" in payload 86 | assert payload["data"]["toLower"] == "abc" 87 | -------------------------------------------------------------------------------- /src/nebulo/gql/convert/delete.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import lru_cache 4 | 5 | from nebulo.config import Config 6 | from nebulo.gql.alias import ( 7 | DeleteInputType, 8 | DeletePayloadType, 9 | Field, 10 | InputObjectType, 11 | NonNull, 12 | ObjectType, 13 | String, 14 | TableInputType, 15 | ) 16 | from nebulo.gql.relay.node_interface import ID 17 | from nebulo.gql.resolve.resolvers.default import default_resolver 18 | from nebulo.sql.table_base import TableProtocol 19 | 20 | """ 21 | deleteAccount(input: DeleteAccountInput!): 22 | deleteAccountPayload 23 | """ 24 | 25 | 26 | @lru_cache() 27 | def delete_entrypoint_factory(sqla_model: TableProtocol, resolver) -> Field: 28 | """deleteAccount""" 29 | relevant_type_name = Config.table_type_name_mapper(sqla_model) 30 | name = f"delete{relevant_type_name}" 31 | args = {"input": NonNull(delete_input_type_factory(sqla_model))} 32 | payload = delete_payload_factory(sqla_model) 33 | return { 34 | name: Field( 35 | payload, 36 | args=args, 37 | resolve=resolver, 38 | description=f"Delete a single {relevant_type_name} using its globally unique id and a patch.", 39 | ) 40 | } 41 | 42 | 43 | @lru_cache() 44 | def delete_input_type_factory(sqla_model: 
TableProtocol) -> InputObjectType: 45 | """DeleteAccountInput!""" 46 | relevant_type_name = Config.table_type_name_mapper(sqla_model) 47 | result_name = f"Delete{relevant_type_name}Input" 48 | 49 | input_object_name = Config.table_name_mapper(sqla_model) 50 | 51 | attrs = { 52 | "nodeId": NonNull(ID), 53 | "clientMutationId": String, 54 | } 55 | return DeleteInputType(result_name, attrs, description=f"All input for the create {relevant_type_name} mutation.") 56 | 57 | 58 | @lru_cache() 59 | def delete_payload_factory(sqla_model: TableProtocol) -> InputObjectType: 60 | """DeleteAccountPayload""" 61 | 62 | relevant_type_name = Config.table_type_name_mapper(sqla_model) 63 | result_name = f"Delete{relevant_type_name}Payload" 64 | 65 | attrs = {"clientMutationId": Field(String, resolve=default_resolver), "nodeId": ID} 66 | 67 | return DeletePayloadType( 68 | result_name, attrs, description=f"The output of our delete {relevant_type_name} mutation", sqla_model=sqla_model 69 | ) 70 | -------------------------------------------------------------------------------- /src/nebulo/server/starlette.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from nebulo.gql.sqla_to_gql import sqla_models_to_graphql_schema 4 | from nebulo.server.exception import http_exception 5 | from nebulo.server.routes import get_graphiql_route, get_graphql_route 6 | from nebulo.sql.reflection.manager import reflect_sqla_models 7 | from sqlalchemy import create_engine 8 | from sqlalchemy.ext.asyncio import create_async_engine 9 | from starlette.applications import Starlette 10 | from starlette.exceptions import HTTPException 11 | from starlette.middleware import Middleware 12 | from starlette.middleware.cors import CORSMiddleware 13 | 14 | 15 | def create_app( 16 | connection: str, 17 | schema: str = "public", 18 | jwt_identifier: Optional[str] = None, 19 | jwt_secret: Optional[str] = None, 20 | default_role: Optional[str] = None, 
21 | ) -> Starlette: 22 | """Instantiate the Starlette app""" 23 | 24 | if not (jwt_identifier is not None) == (jwt_secret is not None): 25 | raise Exception("jwt_token_identifier and jwt_secret must be provided together") 26 | 27 | async_connection = connection.replace("postgresql://", "postgresql+asyncpg://") 28 | engine = create_async_engine(async_connection) 29 | # Reflect database to sqla models 30 | sqla_engine = create_engine(connection) 31 | sqla_models, sql_functions = reflect_sqla_models(engine=sqla_engine, schema=schema) 32 | 33 | # Convert sqla models to graphql schema 34 | gql_schema = sqla_models_to_graphql_schema( 35 | sqla_models, 36 | sql_functions, 37 | jwt_identifier=jwt_identifier, 38 | jwt_secret=jwt_secret, 39 | ) 40 | 41 | graphql_path = "/" 42 | 43 | graphql_route = get_graphql_route( 44 | gql_schema=gql_schema, 45 | engine=engine, 46 | jwt_secret=jwt_secret, 47 | default_role=default_role, 48 | path=graphql_path, 49 | name="graphql", 50 | ) 51 | 52 | graphiql_route = get_graphiql_route(graphiql_path="/graphiql", graphql_path=graphql_path, name="graphiql") 53 | 54 | _app = Starlette( 55 | routes=[graphql_route, graphiql_route], 56 | middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], 57 | exception_handlers={HTTPException: http_exception}, 58 | on_startup=[engine.connect], 59 | on_shutdown=[engine.dispose], 60 | ) 61 | 62 | return _app 63 | -------------------------------------------------------------------------------- /docs/performance.md: -------------------------------------------------------------------------------- 1 | # Performance 2 | 3 | **Performance Enabling Features:** 4 | 5 | * All queries are handled in a single round-trip to the database so there are no [N+1 issues](https://stackoverflow.com/questions/97197/what-is-the-n1-selects-problem-in-orm-object-relational-mapping) 6 | * SQL queries only fetch requested fields 7 | * SQL queries return JSON which significantly reduces database IO when joins are present 8 | * 
Fully async 9 | 10 | 11 | **Benchmarks** 12 | 13 | Performance depends on network, number of workers, log level etc. Despite all that, here are rough figures with Postgres and the web server running on a mid-tier 2017 Macbook Pro. 14 | 15 | ```text 16 | Benchmarking 0.0.0.0 (be patient) 17 | 18 | 19 | Server Software: uvicorn 20 | Server Hostname: 0.0.0.0 21 | Server Port: 5034 22 | 23 | Document Path: / 24 | Document Length: 310 bytes 25 | 26 | Concurrency Level: 10 27 | Time taken for tests: 5.571 seconds 28 | Complete requests: 1000 29 | Failed requests: 0 30 | Total transferred: 436000 bytes 31 | Total body sent: 245000 32 | HTML transferred: 310000 bytes 33 | Requests per second: 179.49 [#/sec] (mean) 34 | Time per request: 55.712 [ms] (mean) 35 | Time per request: 5.571 [ms] (mean, across all concurrent requests) 36 | Transfer rate: 76.43 [Kbytes/sec] received 37 | 42.95 kb/s sent 38 | 119.37 kb/s total 39 | 40 | Connection Times (ms) 41 | min mean[+/-sd] median max 42 | Connect: 0 0 0.0 0 0 43 | Processing: 17 55 18.7 52 98 44 | Waiting: 17 55 18.7 51 98 45 | Total: 18 55 18.7 52 98 46 | 47 | Percentage of the requests served within a certain time (ms) 48 | 50% 52 49 | 66% 60 50 | 75% 65 51 | 80% 69 52 | 90% 80 53 | 95% 90 54 | 98% 94 55 | 99% 95 56 | 100% 98 (longest request) 57 | ``` 58 | 59 | So approximately 180 requests/second responding in sub 100 milliseconds. 60 | 61 | Note that under normal load, response times are significantly faster. 
62 | 63 | ```text 64 | Percentage of the requests served within a certain time (ms) 65 | 50% 17 66 | 66% 19 67 | 75% 19 68 | 80% 20 69 | 90% 21 70 | 95% 24 71 | 98% 26 72 | 99% 30 73 | 100% 38 (longest request) 74 | ``` 75 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | 6 | def read_package_variable(key, filename="__init__.py"): 7 | """Read the value of a variable from the package without importing.""" 8 | module_path = os.path.join("src/nebulo", filename) 9 | with open(module_path) as module: 10 | for line in module: 11 | parts = line.strip().split(" ", 2) 12 | if parts[:-1] == [key, "="]: 13 | return parts[-1].strip("'").strip('"') 14 | return None 15 | 16 | 17 | setup( 18 | name="nebulo", 19 | version=read_package_variable("VERSION"), 20 | description="Nebulo: GraphQL API for PostgreSQL", 21 | author="Oliver Rice", 22 | author_email="oliver@oliverrice.com", 23 | license="MIT", 24 | url="https://github.com/olirice/nebulo", 25 | project_urls={ 26 | "Documentation": "https://olirice.github.io/nebulo/", 27 | "Source Code": "https://github.com/olirice/nebulo", 28 | }, 29 | keywords="graphql sqlalchemy sql api python", 30 | classifiers=[ 31 | "Intended Audience :: Developers", 32 | "Programming Language :: Python :: 3.7", 33 | "Programming Language :: Python :: 3.8", 34 | "Programming Language :: Python :: 3.9", 35 | ], 36 | python_requires=">=3.7", 37 | package_dir={"": "src"}, 38 | packages=find_packages("src"), 39 | include_package_data=True, 40 | entry_points={ 41 | "console_scripts": ["nebulo=nebulo.cli:main", "neb=nebulo.cli:main"], 42 | "pygments.lexers": ["graphqllexer=nebulo.lexer:GraphQLLexer"], 43 | }, 44 | install_requires=[ 45 | "aiofiles==0.5.*", 46 | "appdirs==1.4.3", 47 | "asyncpg", 48 | "cachetools==4.0.*", 49 | "click>=8", 50 | 'dataclasses; 
python_version<"3.7"', 51 | "flupy>=1.2", 52 | "graphql-core==3.1.*", 53 | "inflect==4.1.*", 54 | "parse==1.15.*", 55 | "psycopg2-binary>2.8", 56 | "pyjwt==1.7.*", 57 | "starlette==0.14.*", 58 | "sqlalchemy<2", 59 | "typing-extensions", 60 | "uvicorn>=0.20", 61 | ], 62 | extras_require={ 63 | "test": ["pytest", "pytest-cov", "requests", "pytest-asyncio"], 64 | "dev": ["pylint", "black", "sqlalchemy-stubs", "pre-commit"], 65 | "nvim": ["neovim", "python-language-server"], 66 | "docs": ["mkdocs", "pygments", "pymdown-extensions", "mkautodoc"], 67 | }, 68 | ) 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # VIM 10 | *.swp 11 | 12 | # Mac 13 | .DS 14 | .DS_Store 15 | 16 | # VSCode 17 | .vscode/ 18 | 19 | # PyCharm 20 | .idea/ 21 | 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 
105 | #Pipfile.lock 106 | 107 | # celery beat schedule file 108 | celerybeat-schedule 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /src/nebulo/gql/relay/node_interface.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import typing 5 | 6 | from nebulo.gql.alias import Field, InterfaceType, NonNull, ScalarType 7 | from nebulo.sql.inspect import get_primary_key_columns, get_table_name 8 | from nebulo.sql.statement_helpers import literal_string 9 | from nebulo.text_utils.base64 import from_base64, to_base64 10 | from sqlalchemy import func 11 | from sqlalchemy.sql.selectable import Alias 12 | 13 | 14 | class NodeIdStructure(typing.NamedTuple): 15 | table_name: str 16 | values: typing.Dict[str, typing.Any] 17 | 18 | @classmethod 19 | def from_dict(cls, contents: typing.Dict) -> NodeIdStructure: 20 | res = cls(table_name=contents["table_name"], values=contents["values"]) 21 | return res 22 | 23 | def to_dict(self) -> typing.Dict[str, typing.Any]: 24 | return {"table_name": self.table_name, "values": self.values} 25 | 26 | def serialize(self) -> str: 27 | ser = to_base64(json.dumps(self.to_dict())) 28 | return ser 29 | 30 | @classmethod 31 | def deserialize(cls, serialized: str) -> NodeIdStructure: 32 | contents = json.loads(from_base64(serialized)) 33 | return cls.from_dict(contents) 34 | 35 | 36 | def serialize(value: typing.Union[NodeIdStructure, typing.Dict]): 37 | node_id = 
NodeIdStructure.from_dict(value) if isinstance(value, dict) else value 38 | return node_id.serialize() 39 | 40 | 41 | def to_node_id_sql(sqla_model, query_elem: Alias): 42 | table_name = get_table_name(sqla_model) 43 | 44 | pkey_cols = get_primary_key_columns(sqla_model) 45 | 46 | # Columns selected from query element 47 | vals = [] 48 | for col in pkey_cols: 49 | col_name = str(col.name) 50 | vals.extend([literal_string(col_name), query_elem.c[col_name]]) 51 | 52 | return func.jsonb_build_object( 53 | literal_string("table_name"), 54 | literal_string(table_name), 55 | literal_string("values"), 56 | func.jsonb_build_object(*vals), 57 | ) 58 | 59 | 60 | ID = ScalarType( 61 | "ID", 62 | description="Unique ID for node", 63 | serialize=serialize, 64 | parse_value=NodeIdStructure.deserialize, 65 | parse_literal=lambda x: NodeIdStructure.deserialize(x.value), 66 | ) 67 | 68 | NodeInterface = InterfaceType( 69 | "NodeInterface", 70 | description="An object with a nodeId", 71 | fields={"nodeId": Field(NonNull(ID), description="The global id of the object.", resolve=None)}, 72 | # Maybe not necessary 73 | resolve_type=lambda *args, **kwargs: None, 74 | ) 75 | -------------------------------------------------------------------------------- /src/nebulo/cli.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import click 4 | import uvicorn 5 | from graphql.utilities import print_schema 6 | from nebulo import VERSION 7 | from nebulo.env import EnvManager 8 | from sqlalchemy import create_engine 9 | 10 | 11 | @click.group() 12 | @click.version_option(version=VERSION) 13 | def main(**kwargs): 14 | pass # pragma: no cover 15 | 16 | 17 | @main.command() 18 | @click.option("-c", "--connection", help="Database connection string") 19 | @click.option("-p", "--port", default=5034, help="Web server port") 20 | @click.option("-h", "--host", default="0.0.0.0", help="Host address") 21 | @click.option("-w", 
"--workers", default=1, help="Number of parallel workers") 22 | @click.option("-s", "--schema", default="public", help="SQL schema name") 23 | @click.option("--jwt-identifier", default=None, help='JWT composite type identifier e.g. "public.jwt"') 24 | @click.option("--jwt-secret", default=None, help="Secret key for JWT encryption") 25 | @click.option("--reload/--no-reload", default=False, help="Reload if source files change") 26 | @click.option("--default-role", type=str, default=None, help="Default PostgreSQL role for anonymous users") 27 | def run(connection, schema, host, port, jwt_identifier, jwt_secret, reload, workers, default_role): 28 | """Run the GraphQL Web Server""" 29 | if reload and workers > 1: 30 | click.echo("Reload not supported with workers > 1") 31 | else: 32 | 33 | with EnvManager( 34 | NEBULO_CONNECTION=connection, 35 | NEBULO_SCHEMA=schema, 36 | NEBULO_JWT_IDENTIFIER=jwt_identifier, 37 | NEBULO_JWT_SECRET=jwt_secret, 38 | NEBULO_DEFAULT_ROLE=default_role, 39 | ): 40 | 41 | uvicorn.run("nebulo.server.app:APP", host=host, workers=workers, port=port, log_level="info", reload=reload) 42 | 43 | 44 | @main.command() 45 | @click.option("-c", "--connection", help="Database connection string") 46 | @click.option("-s", "--schema", default="public", help="SQL schema name") 47 | @click.option("-o", "--out-file", type=click.File("w"), default=None, help="Output file path") 48 | def dump_schema(connection, schema, out_file): 49 | """Dump the GraphQL Schema to stdout or file""" 50 | from nebulo.gql.sqla_to_gql import sqla_models_to_graphql_schema 51 | from nebulo.sql.reflection.manager import reflect_sqla_models 52 | 53 | engine = create_engine(connection) 54 | sqla_models, sql_functions = reflect_sqla_models(engine, schema=schema) 55 | schema = sqla_models_to_graphql_schema(sqla_models, sql_functions) 56 | schema_str = print_schema(schema) 57 | click.echo(schema_str, file=out_file) 58 | 
--------------------------------------------------------------------------------
/docs/public_api.md:
--------------------------------------------------------------------------------
# Public API

Nebulo exposes a minimal public API to allow the internals to change freely. The public API is defined as anything documented on this page.

For a complete example see [the SQLAlchemy interop example](sqlalchemy_interop.md)

``` python
from databases import Database
from nebulo.gql.sqla_to_gql import sqla_models_to_graphql_schema
from nebulo.server.exception import http_exception
from nebulo.server.routes import GRAPHIQL_STATIC_FILES, get_graphiql_route, get_graphql_route
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.routing import Mount, Route

from myapp.sqla_models import Author, Book
from myapp.config import DATABASE_URI


def create_app(connection_str, sqla_models, jwt_secret=None, default_role=None) -> Starlette:
    """Create the Starlette app"""

    database = Database(connection_str)

    # Convert sqla models to graphql schema
    gql_schema = sqla_models_to_graphql_schema(sqla_models=sqla_models)

    graphql_path = '/'

    # Build the Starlette GraphQL Route
    graphql_route = get_graphql_route(
        gql_schema=gql_schema,
        database=database,
        jwt_secret=jwt_secret,
        default_role=default_role,
        path=graphql_path,
        name='graphql'
    )

    # Build the Starlette GraphiQL Route and StaticFiles
    graphiql_route = get_graphiql_route(
        graphiql_path='/graphiql',
        graphql_path=graphql_path,
        name='graphiql'
    )

    # Instantiate the Starlette app
    _app = Starlette(
        routes=[graphql_route, graphiql_route],
        exception_handlers={HTTPException: http_exception},
        on_startup=[database.connect],
        on_shutdown=[database.disconnect],
    )

    return _app




# Instantiate the app 60 | APP = create_app(connection_str=DATABASE_URI, sqla_models=[Author, Book]) 61 | ``` 62 | 63 | ### Reflection 64 | 65 | Utilities for reflecting SQL entities as GraphQL entities. 66 | 67 | ##### SQLA to GraphQL 68 | ---- 69 | 70 | ::: nebulo.gql.sqla_to_gql.sqla_models_to_graphql_schema 71 | :docstring: 72 | 73 | 74 | ### Server 75 | 76 | Helpers to serve the GraphQL schema. 77 | 78 | ##### Routes 79 | ---- 80 | 81 | ::: nebulo.server.routes.get_graphql_route 82 | :docstring: 83 | 84 | ---- 85 | 86 | ::: nebulo.server.routes.get_graphiql_route 87 | :docstring: 88 | 89 | ##### Exception Handling 90 | ---- 91 | 92 | ::: nebulo.server.exception.http_exception 93 | :docstring: 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /src/nebulo/sql/reflection/constraint_comments.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import lru_cache 4 | from typing import Dict, Optional 5 | 6 | from nebulo.sql.inspect import get_table_name 7 | from sqlalchemy import text as sql_text 8 | from sqlalchemy.engine import Engine 9 | from sqlalchemy.sql.schema import Constraint 10 | 11 | __all__ = ["populate_constraint_comment"] 12 | 13 | 14 | SchemaName = str 15 | TableName = str 16 | ConstraintName = str 17 | Comment = str 18 | 19 | CommentMap = Dict[SchemaName, Dict[TableName, Dict[ConstraintName, Comment]]] 20 | 21 | 22 | def populate_constraint_comment(engine: Engine, constraint: Constraint) -> None: 23 | """Adds SQL comments on a constraint to the SQLAlchemy constraint's 24 | Constraint.info['comment'] dictionary 25 | """ 26 | 27 | schema: Optional[str] = constraint.table.schema 28 | table_name: str = get_table_name(constraint.table) 29 | constraint_name: Optional[str] = constraint.name 30 | 31 | if schema is None or constraint_name is None: 32 | return 33 | 34 | comment_map: CommentMap = 
reflect_all_constraint_comments(engine=engine, schema=schema) 35 | comment: Optional[str] = comment_map.get(schema, {}).get(table_name, {}).get(constraint_name) 36 | 37 | # constraint.info is "Optional[Mapping[str, Any]]" 38 | if not hasattr(constraint, "info"): 39 | constraint.info = {"comment": comment} 40 | else: 41 | constraint.info["comment"] = comment # type: ignore 42 | 43 | return 44 | 45 | 46 | @lru_cache() 47 | def reflect_all_constraint_comments(engine, schema: str) -> CommentMap: 48 | """Collect a mapping of constraint comments""" 49 | 50 | sql = sql_text( 51 | """ 52 | select 53 | c.relnamespace::regnamespace::text schemaname, 54 | c.relname tablename, 55 | t.conname constraintname, 56 | d.description comment_body 57 | from pg_class c 58 | join pg_constraint t 59 | on c.oid = t.conrelid 60 | join pg_description d 61 | on t.oid = d.objoid 62 | and t.tableoid = d.classoid 63 | where 64 | c.relnamespace::regnamespace::text = :schema 65 | """ 66 | ) 67 | 68 | results = engine.execute(sql, schema=schema).fetchall() 69 | 70 | comment_map: CommentMap = {} 71 | 72 | for schema_name, table_name, constraint_name, comment in results: 73 | comment_map[schema_name] = comment_map.get(schema_name, {}) 74 | comment_map[schema_name][table_name] = comment_map[schema_name].get(table_name, {}) 75 | comment_map[schema_name][table_name][constraint_name] = comment 76 | 77 | return comment_map 78 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Nebulo 2 | 3 |

4 | 5 | Test Status 6 | 7 | 8 | Pre-commit Status 9 | 10 | 11 | 12 | Codestyle Black 13 | 14 | 15 |

16 |

17 | Python version 18 | License 19 | PyPI version 20 | Download count 21 |

22 | 23 | --- 24 | 25 | **Documentation**: https://olirice.github.io/nebulo 26 | 27 | **Source Code**: https://github.com/olirice/nebulo 28 | 29 | --- 30 | 31 | **Instant GraphQL API for PostgreSQL** 32 | 33 | [Reflect](https://en.wikipedia.org/wiki/Reflection_(computer_programming)) a highly performant [GraphQL](https://graphql.org/learn/) API from an existing [PostgreSQL](https://www.postgresql.org/) database. 34 | 35 | Nebulo is a python library for building GraphQL APIs on top of PostgreSQL. It has a command line interface for reflecting databases with 0 code or can be added to existing [SQLAlchemy](https://www.sqlalchemy.org/) projects. 36 | 37 | 38 | ## TL;DR 39 | 40 | First, install nebulo 41 | ```shell 42 | $ pip install nebulo 43 | ``` 44 | 45 | Then point the CLI at an existing PostgreSQL database using connection string format `postgresql://:@:/` 46 | ```shell 47 | neb run -c postgresql://nebulo_user:password@localhost:4443/nebulo_db 48 | ``` 49 | 50 | Visit your shiny new GraphQL API at [http://localhost:5034/graphiql](http://localhost:5034/graphiql) 51 | 52 | ![graphiql image](images/graphiql.png) 53 | 54 | Next, check out the [quickstart](quickstart.md) guide for a small end-to-end example. 55 | 56 |

—— ——

57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nebulo 2 | 3 |

4 | 5 | Test Status 6 | 7 | 8 | Pre-commit Status 9 | 10 | 11 | 12 | Codestyle Black 13 | 14 | 15 |

16 |

17 | Python version 18 | License 19 | PyPI version 20 | Download count 21 |

22 | 23 | --- 24 | 25 | **Documentation**: https://olirice.github.io/nebulo 26 | 27 | **Source Code**: https://github.com/olirice/nebulo 28 | 29 | --- 30 | 31 | **Instant GraphQL API for PostgreSQL** 32 | 33 | [Reflect](https://en.wikipedia.org/wiki/Reflection_(computer_programming)) a highly performant [GraphQL](https://graphql.org/learn/) API from an existing [PostgreSQL](https://www.postgresql.org/) database. 34 | 35 | Nebulo is a python library for building GraphQL APIs on top of PostgreSQL. It has a command line interface for reflecting databases with 0 code or can be added to existing [SQLAlchemy](https://www.sqlalchemy.org/) projects. 36 | 37 | 38 | ## TL;DR 39 | 40 | First, install nebulo 41 | ```shell 42 | $ pip install nebulo 43 | ``` 44 | 45 | Then point the CLI at an existing PostgreSQL database using connection string format `postgresql://:@:/` 46 | ```shell 47 | neb run -c postgresql://nebulo_user:password@localhost:4443/nebulo_db 48 | ``` 49 | 50 | Visit your shiny new GraphQL API at [http://localhost:5034/graphiql](http://localhost:5034/graphiql) 51 | 52 | 53 | ![graphiql image](docs/images/graphiql.png) 54 | 55 | Next, check out the [docs](https://olirice.github.io/nebulo/quickstart/) guide for a small end-to-end example. 56 | 57 |

—— ——

58 | -------------------------------------------------------------------------------- /src/test/integration_relationship_test.py: -------------------------------------------------------------------------------- 1 | from nebulo.gql.relay.node_interface import NodeIdStructure 2 | 3 | SQL_UP = """ 4 | CREATE TABLE account ( 5 | id serial primary key, 6 | name text not null, 7 | created_at timestamp without time zone default (now() at time zone 'utc') 8 | ); 9 | 10 | INSERT INTO account (id, name) VALUES 11 | (1, 'oliver'), 12 | (2, 'rachel'), 13 | (3, 'sophie'), 14 | (4, 'buddy'); 15 | 16 | 17 | create table offer ( 18 | id serial primary key, --integer primary key autoincrement, 19 | currency text, 20 | account_id_not_null int not null, 21 | account_id_nullable int, 22 | 23 | constraint fk_offer_account_id_not_null 24 | foreign key (account_id_not_null) 25 | references account (id), 26 | 27 | constraint fk_offer_account_id_nullable 28 | foreign key (account_id_nullable) 29 | references account (id) 30 | 31 | ); 32 | 33 | INSERT INTO offer (currency, account_id_not_null, account_id_nullable) VALUES 34 | ('usd', 2, 2), 35 | ('gbp', 2, 2), 36 | ('eur', 3, null), 37 | ('jpy', 4, 4); 38 | """ 39 | 40 | 41 | def test_query_one_to_many(client_builder): 42 | client = client_builder(SQL_UP) 43 | account_id = 2 44 | node_id = NodeIdStructure(table_name="account", values={"id": account_id}).serialize() 45 | gql_query = f""" 46 | {{ 47 | account(nodeId: "{node_id}") {{ 48 | id 49 | offersByIdToAccountIdNullable {{ 50 | edges {{ 51 | node {{ 52 | id 53 | currency 54 | }} 55 | }} 56 | }} 57 | }} 58 | }} 59 | """ 60 | with client: 61 | resp = client.post("/", json={"query": gql_query}) 62 | assert resp.status_code == 200 63 | 64 | result = resp.json() 65 | assert result["errors"] == [] 66 | assert result["data"]["account"]["id"] == account_id 67 | offers_by_id = result["data"]["account"]["offersByIdToAccountIdNullable"] 68 | currencies = {x["node"]["currency"] for x in 
offers_by_id["edges"]} 69 | assert "usd" in currencies and "gbp" in currencies 70 | 71 | # Fails because sql resolver not applying join correctly 72 | assert len(offers_by_id["edges"]) == 2 73 | 74 | 75 | def test_query_many_to_one(client_builder): 76 | client = client_builder(SQL_UP) 77 | account_id = 2 78 | gql_query = """ 79 | { 80 | allOffers { 81 | edges { 82 | node { 83 | id 84 | accountByAccountIdNotNullToId { 85 | name 86 | } 87 | } 88 | } 89 | } 90 | } 91 | """ 92 | with client: 93 | resp = client.post("/", json={"query": gql_query}) 94 | assert resp.status_code == 200 95 | 96 | result = resp.json() 97 | assert result["errors"] == [] 98 | -------------------------------------------------------------------------------- /src/nebulo/gql/convert/connection.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import lru_cache 4 | 5 | from nebulo.config import Config 6 | from nebulo.gql.alias import Argument, ConnectionType, EdgeType, Field, InputObjectType, Int, List, NonNull 7 | from nebulo.gql.convert.column import convert_column_to_input 8 | from nebulo.gql.relay.cursor import Cursor 9 | from nebulo.gql.relay.page_info import PageInfo 10 | from nebulo.gql.resolve.resolvers.default import default_resolver 11 | from nebulo.sql.inspect import get_columns 12 | from nebulo.sql.table_base import TableProtocol 13 | 14 | __all__ = ["connection_field_factory"] 15 | 16 | 17 | @lru_cache() 18 | def connection_field_factory(sqla_model: TableProtocol, resolver, not_null=False) -> Field: 19 | relevant_type_name = Config.table_type_name_mapper(sqla_model) 20 | connection = connection_factory(sqla_model) 21 | condition = condition_factory(sqla_model) 22 | args = { 23 | "first": Argument(Int, description="", out_name=None), 24 | "last": Argument(Int), 25 | "before": Argument(Cursor), 26 | "after": Argument(Cursor), 27 | "condition": Argument(condition), 28 | } 29 | return Field( 30 
| NonNull(connection) if not_null else connection, 31 | args=args, 32 | resolve=resolver, 33 | description=f"Reads and enables pagination through a set of {relevant_type_name}", 34 | ) 35 | 36 | 37 | @lru_cache() 38 | def edge_factory(sqla_model: TableProtocol) -> EdgeType: 39 | from .table import table_factory 40 | 41 | name = Config.table_type_name_mapper(sqla_model) + "Edge" 42 | 43 | def build_attrs(): 44 | return {"cursor": Field(Cursor), "node": Field(table_factory(sqla_model))} 45 | 46 | edge = EdgeType(name=name, fields=build_attrs, description="", sqla_model=sqla_model) 47 | return edge 48 | 49 | 50 | @lru_cache() 51 | def condition_factory(sqla_model: TableProtocol) -> InputObjectType: 52 | result_name = f"{Config.table_name_mapper(sqla_model)}Condition" 53 | 54 | attrs = {} 55 | for column in get_columns(sqla_model): 56 | field_key = Config.column_name_mapper(column) 57 | attrs[field_key] = convert_column_to_input(column) 58 | return InputObjectType(result_name, attrs, description="") 59 | 60 | 61 | @lru_cache() 62 | def connection_factory(sqla_model: TableProtocol) -> ConnectionType: 63 | name = Config.table_type_name_mapper(sqla_model) + "Connection" 64 | 65 | def build_attrs(): 66 | edge = edge_factory(sqla_model) 67 | return { 68 | "edges": Field(NonNull(List(NonNull(edge))), resolve=default_resolver), 69 | "pageInfo": Field(NonNull(PageInfo), resolve=default_resolver), 70 | "totalCount": Field(NonNull(Int), resolve=default_resolver), 71 | } 72 | 73 | return_type = ConnectionType(name=name, fields=build_attrs, description="", sqla_model=sqla_model) 74 | return return_type 75 | -------------------------------------------------------------------------------- /src/nebulo/sql/reflection/manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List, Tuple 4 | 5 | from nebulo.sql.inspect import get_constraints 6 | from 
nebulo.sql.reflection.constraint_comments import populate_constraint_comment 7 | from nebulo.sql.reflection.function import SQLFunction, reflect_functions 8 | from nebulo.sql.reflection.names import rename_table, rename_to_many_collection, rename_to_one_collection 9 | from nebulo.sql.reflection.types import reflect_composites 10 | from nebulo.sql.reflection.views import reflect_views 11 | from nebulo.sql.table_base import TableProtocol 12 | from sqlalchemy import MetaData, types 13 | from sqlalchemy.dialects.postgresql import base as pg_base 14 | from sqlalchemy.engine import Engine 15 | from sqlalchemy.ext.automap import automap_base 16 | 17 | 18 | def reflect_sqla_models(engine: Engine, schema: str = "public") -> Tuple[List[TableProtocol], List[SQLFunction]]: 19 | """Reflect SQLAlchemy Declarative Models from a database connection""" 20 | 21 | meta = MetaData() 22 | declarative_base = automap_base(metadata=meta) 23 | 24 | # Register event listeners to apply GQL attr keys to columns 25 | 26 | # Retrieve a copy of the full type map 27 | basic_type_map = pg_base.ischema_names.copy() 28 | 29 | # Reflect composite types (not supported by sqla) 30 | composites = reflect_composites(engine, schema, basic_type_map) 31 | 32 | # Register composite types with SQLA to make them available during reflection 33 | pg_base.ischema_names.update({type_name: type_ for (type_schema, type_name), type_ in composites.items()}) 34 | 35 | # Retrieve a copy of the full type map 36 | # NOTE: types are not schema namespaced so collisions can occur reflecting tables 37 | type_map = pg_base.ischema_names.copy() 38 | type_map["bool"] = types.Boolean # type: ignore 39 | 40 | # Reflect views as SQLA ORM table 41 | views = reflect_views(engine=engine, schema=schema, declarative_base=declarative_base) 42 | 43 | declarative_base.prepare( 44 | engine, 45 | reflect=True, 46 | schema=schema, 47 | classname_for_table=rename_table, 48 | name_for_scalar_relationship=rename_to_one_collection, 49 |
name_for_collection_relationship=rename_to_many_collection, 50 | ) 51 | 52 | # Register tables as types so functions can return a row of a table 53 | tables = list(declarative_base.classes) + views 54 | for table in tables + views: # NOTE(review): `tables` already includes `views`, so views are visited twice; redundant but idempotent — confirm before simplifying to `tables` 55 | type_map[table.__table__.name] = table 56 | 57 | # Reflect functions, allowing composite types 58 | functions = reflect_functions(engine=engine, schema=schema, type_map=type_map) 59 | 60 | # Populate constraint.info['comment'] for constraint comments because 61 | # sqlalchemy does not support constraint comment reflection 62 | # https://github.com/sqlalchemy/sqlalchemy/issues/5667 63 | for table in tables: 64 | for constraint in get_constraints(table): 65 | populate_constraint_comment(engine=engine, constraint=constraint) 66 | 67 | # SQLA Tables 68 | return (tables, functions) 69 | -------------------------------------------------------------------------------- /src/nebulo/server/routes/graphiql.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from pathlib import Path 4 | 5 | from starlette.requests import Request 6 | from starlette.responses import HTMLResponse 7 | from starlette.routing import Mount, Route 8 | from starlette.staticfiles import StaticFiles 9 | 10 | GRAPHIQL_STATIC_PATH = Path(os.path.abspath(__file__)).parent.parent.parent.resolve() / "static" 11 | 12 | GRAPHIQL_STATIC_FILES = StaticFiles(directory=GRAPHIQL_STATIC_PATH) 13 | 14 | 15 | def get_graphiql_route(graphiql_path: str = "/graphiql", graphql_path: str = "/", name="graphiql") -> Mount: 16 | """Return mountable routes serving GraphiQL interactive API explorer 17 | 18 | **Parameters** 19 | 20 | * **graphiql_path**: _str_ = URL path to GraphiQL html page 21 | * **graphql_path**: _str_ = URL path to the GraphQL route 22 | * **name**: _str_ = Name for the mount 23 | """ 24 | 25 | async def graphiql_endpoint(_: Request) -> HTMLResponse: 26 | """Return the HTMLResponse for
GraphiQL GraphQL explorer configured to hit the correct endpoint""" 27 | html_text = build_graphiql_html(graphiql_path, graphql_path) 28 | return HTMLResponse(html_text) 29 | 30 | graphiql_html_route = Route("/", graphiql_endpoint, methods=["GET"]) 31 | graphiql_statics = Mount("/static", GRAPHIQL_STATIC_FILES, name="static") 32 | 33 | return Mount(graphiql_path, routes=[graphiql_html_route, graphiql_statics], name=name) 34 | 35 | 36 | @lru_cache() 37 | def build_graphiql_html( 38 | graphiql_route_path: str, 39 | graphql_route_path: str, 40 | ) -> str: 41 | """Return the raw HTML for GraphiQL GraphQL explorer""" 42 | 43 | text = GRAPHIQL.replace("{{GRAPHIQL_PATH}}", graphiql_route_path) 44 | text = text.replace("{{REQUEST_PATH}}", f'"{graphql_route_path}"') 45 | return text 46 | 47 | 48 | GRAPHIQL = """ 49 | 50 | 51 | Nebulo GraphiQL 52 | 53 | 54 | 55 |
56 | 57 | 61 | 65 | 69 | 70 | 88 | 89 | 90 | """ 91 | -------------------------------------------------------------------------------- /src/nebulo/gql/convert/create.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import lru_cache 4 | 5 | from nebulo.config import Config 6 | from nebulo.gql.alias import ( 7 | CreateInputType, 8 | CreatePayloadType, 9 | Field, 10 | InputObjectType, 11 | NonNull, 12 | ObjectType, 13 | String, 14 | TableInputType, 15 | ) 16 | from nebulo.gql.convert.column import convert_column_to_input 17 | from nebulo.gql.relay.node_interface import ID 18 | from nebulo.gql.resolve.resolvers.default import default_resolver 19 | from nebulo.sql.inspect import get_columns 20 | from nebulo.sql.table_base import TableProtocol 21 | 22 | """ 23 | createAccount(input: CreateAccountInput!): 24 | CreateAccountPayload 25 | """ 26 | 27 | 28 | @lru_cache() 29 | def create_entrypoint_factory(sqla_model: TableProtocol, resolver) -> Field: 30 | """createAccount""" 31 | relevant_type_name = Config.table_type_name_mapper(sqla_model) 32 | name = f"create{relevant_type_name}" 33 | args = {"input": NonNull(create_input_type_factory(sqla_model))} 34 | payload = create_payload_factory(sqla_model) 35 | return {name: Field(payload, args=args, resolve=resolver, description=f"Creates a single {relevant_type_name}.")} 36 | 37 | 38 | @lru_cache() 39 | def create_input_type_factory(sqla_model: TableProtocol) -> CreateInputType: 40 | """CreateAccountInput!""" 41 | relevant_type_name = Config.table_type_name_mapper(sqla_model) 42 | result_name = f"Create{relevant_type_name}Input" 43 | 44 | input_object_name = Config.table_name_mapper(sqla_model) 45 | 46 | attrs = {"clientMutationId": String, input_object_name: NonNull(input_type_factory(sqla_model))} 47 | return CreateInputType(result_name, attrs, description=f"All input for the create {relevant_type_name} mutation.") 48 | 49 | 50 | 
def input_type_factory(sqla_model: TableProtocol) -> TableInputType: 51 | """AccountInput""" 52 | relevant_type_name = Config.table_type_name_mapper(sqla_model) 53 | result_name = f"{relevant_type_name}Input" 54 | 55 | attrs = {} 56 | for column in get_columns(sqla_model): 57 | if not Config.exclude_create(column): 58 | field_key = Config.column_name_mapper(column) 59 | attrs[field_key] = convert_column_to_input(column) 60 | return TableInputType(result_name, attrs, description=f"An input for mutations affecting {relevant_type_name}.") 61 | 62 | 63 | @lru_cache() 64 | def create_payload_factory(sqla_model: TableProtocol) -> CreatePayloadType: 65 | """CreateAccountPayload""" 66 | from nebulo.gql.convert.table import table_factory 67 | 68 | relevant_type_name = Config.table_type_name_mapper(sqla_model) 69 | relevant_attr_name = Config.table_name_mapper(sqla_model) 70 | result_name = f"Create{relevant_type_name}Payload" 71 | 72 | attrs = { 73 | "clientMutationId": Field(String, resolve=default_resolver), 74 | "nodeId": ID, 75 | relevant_attr_name: Field( 76 | NonNull(table_factory(sqla_model)), 77 | description=f"The {relevant_type_name} that was created by this mutation.", 78 | resolve=default_resolver, 79 | ), 80 | } 81 | 82 | return CreatePayloadType( 83 | result_name, attrs, description=f"The output of our create {relevant_type_name} mutation", sqla_model=sqla_model 84 | ) 85 | -------------------------------------------------------------------------------- /examples/sqla_interop/app.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | from nebulo.gql.sqla_to_gql import sqla_models_to_graphql_schema 3 | from nebulo.server.exception import http_exception 4 | from nebulo.server.routes import get_graphiql_route, get_graphql_route 5 | from sqlalchemy import Column, DateTime, ForeignKey, Integer, Text, create_engine 6 | from sqlalchemy import text as sql_text 7 | from sqlalchemy.ext.asyncio import 
create_async_engine 8 | from sqlalchemy.ext.declarative import declarative_base 9 | from sqlalchemy.orm import relationship 10 | from starlette.applications import Starlette 11 | from starlette.exceptions import HTTPException 12 | 13 | # Config 14 | DATABASE_URI = "postgresql://app_user:app_password@localhost:5522/nebulo_example" 15 | 16 | 17 | ##################### 18 | # SQLAlchemy Models # 19 | ##################### 20 | 21 | Base = declarative_base() 22 | 23 | 24 | class Author(Base): 25 | __tablename__ = "author" 26 | 27 | id = Column(Integer, primary_key=True, comment="@exclude create, update") 28 | name = Column(Text, nullable=False) 29 | created_at = Column( 30 | DateTime, 31 | nullable=False, 32 | server_default=sql_text("now()"), 33 | comment="@exclude create, update", 34 | ) 35 | 36 | books = relationship("Book", uselist=True, backref="Author") 37 | 38 | 39 | class Book(Base): 40 | __tablename__ = "book" 41 | 42 | id = Column(Integer, primary_key=True, comment="@exclude create, update") 43 | title = Column(Text, nullable=False) 44 | author_id = Column(Integer, ForeignKey("author.id"), nullable=False) 45 | created_at = Column( 46 | DateTime, 47 | nullable=False, 48 | default=sql_text("now()"), 49 | comment="@exclude create, update", 50 | ) 51 | 52 | # backref: Author 53 | 54 | 55 | ################################# 56 | # Starlette Application Factory # 57 | ################################# 58 | 59 | 60 | def create_app(connection_str, sqla_models) -> Starlette: 61 | """Create the Starlette app""" 62 | 63 | async_connection = connection_str.replace("postgresql://", "postgresql+asyncpg://") 64 | engine = create_async_engine(async_connection) 65 | 66 | # Convert sqla models to graphql schema 67 | gql_schema = sqla_models_to_graphql_schema(sqla_models) 68 | 69 | # Build the Starlette GraphQL Route 70 | graphql_route = get_graphql_route(gql_schema=gql_schema, engine=engine) 71 | 72 | # Build the Starlette GraphiQL Route and StaticFiles 73 | graphiql_route = 
get_graphiql_route() 74 | 75 | # Instantiate the Starlette app 76 | _app = Starlette( 77 | routes=[graphql_route, graphiql_route], 78 | exception_handlers={HTTPException: http_exception}, 79 | on_startup=[engine.connect], 80 | on_shutdown=[engine.dispose], 81 | ) 82 | 83 | return _app 84 | 85 | 86 | # Instantiate the app 87 | APP = create_app(connection_str=DATABASE_URI, sqla_models=[Author, Book]) 88 | 89 | 90 | if __name__ == "__main__": 91 | 92 | # Create Tables 93 | with create_engine(DATABASE_URI).connect() as sqla_engine: 94 | Base.metadata.create_all(bind=sqla_engine) 95 | 96 | uvicorn.run( 97 | "app:APP", 98 | host="0.0.0.0", 99 | port=5084, 100 | log_level="info", 101 | reload=False, 102 | ) 103 | -------------------------------------------------------------------------------- /src/nebulo/gql/convert/table.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name,unsubscriptable-object 2 | from __future__ import annotations 3 | 4 | from functools import lru_cache 5 | 6 | from nebulo.config import Config 7 | from nebulo.gql.alias import Argument, Field, NonNull, TableType 8 | from nebulo.gql.convert.column import convert_column 9 | from nebulo.gql.relay.node_interface import ID, NodeInterface 10 | from nebulo.gql.resolve.resolvers.default import default_resolver 11 | from nebulo.sql.inspect import get_columns, get_foreign_key_constraint_from_relationship, get_relationships, is_nullable 12 | from nebulo.sql.table_base import TableProtocol 13 | from sqlalchemy.orm import interfaces 14 | 15 | 16 | def table_field_factory(sqla_model: TableProtocol, resolver) -> Field: 17 | relevant_type_name = Config.table_type_name_mapper(sqla_model) 18 | node = table_factory(sqla_model) 19 | return Field( 20 | node, 21 | args={"nodeId": Argument(NonNull(ID))}, 22 | resolve=resolver, 23 | description=f"Reads a single {relevant_type_name} using its globally unique ID", 24 | ) 25 | 26 | 27 | @lru_cache() 28 | def 
table_factory(sqla_model: TableProtocol) -> TableType: 29 | """ 30 | Reflects a SQLAlchemy table into a graphql-core GraphQLObjectType 31 | 32 | Parameters 33 | ---------- 34 | sqla_model 35 | A SQLAlchemy ORM Table 36 | 37 | """ 38 | from .connection import connection_field_factory 39 | 40 | name = Config.table_type_name_mapper(sqla_model) 41 | 42 | def build_attrs(): 43 | attrs = {} 44 | 45 | # Override id to relay standard 46 | attrs["nodeId"] = Field(NonNull(ID), resolve=default_resolver) 47 | 48 | for column in get_columns(sqla_model): 49 | if not Config.exclude_read(column): 50 | key = Config.column_name_mapper(column) 51 | attrs[key] = convert_column(column) 52 | 53 | for relationship in get_relationships(sqla_model): 54 | direction = relationship.direction 55 | to_sqla_model = relationship.mapper.class_ 56 | 57 | fkey_constraint = get_foreign_key_constraint_from_relationship(relationship) 58 | if fkey_constraint and Config.exclude_read(fkey_constraint): 59 | continue 60 | 61 | relationship_is_nullable = is_nullable(relationship) 62 | 63 | # Name of the attribute on the model 64 | attr_key = Config.relationship_name_mapper(relationship) 65 | 66 | # If this model has 1 counterpart, do not use a list 67 | if direction == interfaces.MANYTOONE: 68 | _type = table_factory(to_sqla_model) 69 | _type = NonNull(_type) if not relationship_is_nullable else _type 70 | attrs[attr_key] = Field(_type, resolve=default_resolver) 71 | 72 | # Otherwise, set it up as a connection 73 | elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): 74 | 75 | connection_field = connection_field_factory( 76 | to_sqla_model, resolver=default_resolver, not_null=relationship_is_nullable 77 | ) 78 | attrs[attr_key] = connection_field 79 | 80 | return attrs 81 | 82 | return_type = TableType( 83 | name=name, fields=build_attrs, interfaces=[NodeInterface], description="", sqla_model=sqla_model 84 | ) 85 | 86 | return return_type 87 | 
-------------------------------------------------------------------------------- /src/nebulo/server/routes/graphql.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Awaitable, Dict, Optional 2 | 3 | from graphql import graphql as graphql_exec 4 | from nebulo.gql.alias import Schema 5 | from nebulo.server.jwt import get_jwt_claims_handler 6 | from sqlalchemy.ext.asyncio import AsyncEngine 7 | from starlette.exceptions import HTTPException 8 | from starlette.requests import Request 9 | from starlette.responses import JSONResponse 10 | from starlette.routing import Route 11 | 12 | __all__ = ["get_graphql_route"] 13 | 14 | 15 | def get_graphql_route( 16 | gql_schema: Schema, 17 | engine: AsyncEngine, 18 | path: str = "/", 19 | jwt_secret: Optional[str] = None, 20 | default_role: Optional[str] = None, 21 | name: Optional[str] = None, 22 | ) -> Route: 23 | """Create a Starlette Route to serve GraphQL requests 24 | 25 | **Parameters** 26 | 27 | * **schema**: _Schema_ = A GraphQL-core schema 28 | * **engine**: _AsyncEngine_ = Database connection object for communicating with PostgreSQL 29 | * **path**: _str_ = URL path to serve GraphQL from, e.g. 
'/' 30 | * **jwt_secret**: _str_ = secret key used to encrypt JWT contents 31 | * **default_role**: _str_ = Default SQL role to use when serving unauthenticated requests 32 | * **name**: _str_ = Name of the GraphQL serving Starlette route 33 | """ 34 | 35 | get_jwt_claims = get_jwt_claims_handler(jwt_secret) 36 | 37 | async def graphql_endpoint(request: Request) -> JSONResponse: 38 | 39 | query = await get_query(request) 40 | variables = await get_variables(request) 41 | jwt_claims = await get_jwt_claims(request) 42 | request_context = { 43 | "request": request, 44 | "engine": engine, 45 | "query": query, 46 | "variables": variables, 47 | "jwt_claims": jwt_claims, 48 | "default_role": default_role, 49 | } 50 | result = await graphql_exec( 51 | schema=gql_schema, 52 | source=query, 53 | context_value=request_context, 54 | variable_values=variables, 55 | ) 56 | errors = result.errors 57 | result_dict = { 58 | "data": result.data, 59 | "errors": [error.formatted for error in errors or []], 60 | } 61 | return JSONResponse(result_dict) 62 | 63 | graphql_route = Route(path=path, endpoint=graphql_endpoint, methods=["POST"], name=name) 64 | 65 | return graphql_route 66 | 67 | 68 | async def get_query(request: Request) -> str: 69 | """Retrieve the GraphQL query from the Starlette Request""" 70 | 71 | content_type = request.headers.get("content-type", "") 72 | if content_type == "application/graphql": 73 | return await request.body() # NOTE(review): Request.body() returns bytes, not str — confirm downstream accepts bytes for this content type 74 | if content_type == "application/json": 75 | return (await request.json())["query"] 76 | raise HTTPException(400, "content-type header must be set") 77 | 78 | 79 | async def get_variables(request: Request) -> Dict[str, Any]: 80 | """Retrieve the GraphQL variables from the Starlette Request""" 81 | 82 | content_type = request.headers.get("content-type", "") 83 | if content_type == "application/graphql": 84 | raise NotImplementedError("Getting variables from graphql content type not known") 85 | 86 | if content_type ==
"application/json": 87 | return (await request.json()).get("variables", {}) 88 | raise HTTPException(400, "content-type header must be set") 89 | -------------------------------------------------------------------------------- /src/nebulo/gql/resolve/transpile/mutation_builder.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name 2 | from __future__ import annotations 3 | 4 | from nebulo.config import Config 5 | from nebulo.gql.alias import CreatePayloadType, DeletePayloadType, UpdatePayloadType 6 | from nebulo.gql.parse_info import ASTNode 7 | from nebulo.gql.relay.node_interface import to_node_id_sql 8 | from nebulo.gql.resolve.transpile.query_builder import field_name_to_column 9 | from nebulo.sql.inspect import get_primary_key_columns 10 | 11 | 12 | def build_mutation(tree: ASTNode): 13 | """Dispatch for Mutation Types 14 | 15 | Returns an executable sqlalchemy statment and a result post-processor 16 | """ 17 | 18 | if isinstance(tree.return_type, CreatePayloadType): 19 | return build_insert(tree) 20 | elif isinstance(tree.return_type, UpdatePayloadType): 21 | return build_update(tree) 22 | elif isinstance(tree.return_type, DeletePayloadType): 23 | return build_delete(tree) 24 | else: 25 | raise Exception("Unknown mutation type") 26 | 27 | 28 | def build_insert(tree: ASTNode): 29 | return_type = tree.return_type 30 | # Find the table type 31 | return_sqla_model = return_type.sqla_model 32 | table_input_arg_name = Config.table_name_mapper(return_sqla_model) 33 | input_values = tree.args["input"][table_input_arg_name] 34 | col_name_to_value = {} 35 | for arg_name, arg_value in input_values.items(): 36 | col = field_name_to_column(return_sqla_model, arg_name) 37 | col_name_to_value[col.name] = arg_value 38 | 39 | core_table = return_sqla_model.__table__ 40 | 41 | query = ( 42 | core_table.insert() 43 | .values(**col_name_to_value) 44 | .returning(to_node_id_sql(return_sqla_model, 
core_table).label("nodeId")) 45 | ) 46 | return query 47 | 48 | 49 | def build_update(tree: ASTNode): 50 | return_type = tree.return_type 51 | sqla_model = return_type.sqla_model 52 | 53 | # Where Clause 54 | pkey_cols = get_primary_key_columns(sqla_model) 55 | node_id = tree.args["input"]["nodeId"] 56 | pkey_clause = [col == node_id.values[str(col.name)] for col in pkey_cols] 57 | 58 | # Find the table type 59 | return_sqla_model = return_type.sqla_model 60 | table_input_arg_name = Config.table_name_mapper(return_sqla_model) 61 | input_values = tree.args["input"][table_input_arg_name] 62 | col_name_to_value = {} 63 | for arg_name, arg_value in input_values.items(): 64 | col = field_name_to_column(return_sqla_model, arg_name) 65 | col_name_to_value[col.name] = arg_value 66 | 67 | core_table = return_sqla_model.__table__ 68 | 69 | query = ( 70 | core_table.update() 71 | .where(*pkey_clause) 72 | .values(**col_name_to_value) 73 | .returning(to_node_id_sql(return_sqla_model, core_table).label("nodeId")) 74 | ) 75 | return query 76 | 77 | 78 | def build_delete(tree: ASTNode): 79 | return_type = tree.return_type 80 | sqla_model = return_type.sqla_model 81 | 82 | # Where Clause 83 | pkey_cols = get_primary_key_columns(sqla_model) 84 | node_id = tree.args["input"]["nodeId"] 85 | pkey_clause = [col == node_id.values[str(col.name)] for col in pkey_cols] 86 | 87 | # Find the table type 88 | return_sqla_model = return_type.sqla_model 89 | core_table = return_sqla_model.__table__ 90 | 91 | query = ( 92 | core_table.delete().where(*pkey_clause).returning(to_node_id_sql(return_sqla_model, core_table).label("nodeId")) 93 | ) 94 | return query 95 | -------------------------------------------------------------------------------- /src/nebulo/gql/convert/update.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import typing as t 4 | from functools import lru_cache 5 | 6 | from nebulo.config import 
@lru_cache()
def update_entrypoint_factory(sqla_model: TableProtocol, resolver) -> t.Dict[str, Field]:
    """Build the ``update<Type>`` mutation entrypoint, e.g. ``updateAccount``.

    Returns a single-entry dict mapping the mutation field name to its Field
    so the caller can merge it into the schema's mutation fields.
    """
    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    name = f"update{relevant_type_name}"
    args = {"input": NonNull(update_input_type_factory(sqla_model))}
    payload = update_payload_factory(sqla_model)
    return {
        name: Field(
            payload,
            args=args,
            resolve=resolver,
            description=f"Updates a single {relevant_type_name} using its globally unique id and a patch.",
        )
    }


@lru_cache()
def update_input_type_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the input type for the update mutation, e.g. ``UpdateAccountInput!``.

    Carries the target row's ``nodeId``, an optional ``clientMutationId``, and
    a patch object keyed by the table's mapped name.
    """
    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    result_name = f"Update{relevant_type_name}Input"

    input_object_name = Config.table_name_mapper(sqla_model)

    attrs = {
        "nodeId": NonNull(ID),
        "clientMutationId": String,
        input_object_name: NonNull(patch_type_factory(sqla_model)),
    }
    # BUG FIX: description previously read "create"; this is the update mutation
    return UpdateInputType(result_name, attrs, description=f"All input for the update {relevant_type_name} mutation.")


def patch_type_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the patch input type, e.g. ``AccountPatch``.

    Every column not excluded from updates (via @exclude comment directives)
    becomes an input field, so callers may patch any subset of columns.
    """
    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    result_name = f"{relevant_type_name}Patch"

    attrs = {}
    for column in get_columns(sqla_model):
        if not Config.exclude_update(column):
            field_key = Config.column_name_mapper(column)
            column_field = convert_column_to_input(column)
            # TODO Unwrap not null here
            attrs[field_key] = column_field
    return TableInputType(result_name, attrs, description=f"An input for mutations affecting {relevant_type_name}.")


@lru_cache()
def update_payload_factory(sqla_model: TableProtocol) -> ObjectType:
    """Build the output type for the update mutation, e.g. ``UpdateAccountPayload``.

    Note: return annotation corrected from InputObjectType — a mutation payload
    is an output ObjectType.
    """
    # Imported here to avoid a circular import with table.py
    from nebulo.gql.convert.table import table_factory

    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    relevant_attr_name = Config.table_name_mapper(sqla_model)
    result_name = f"Update{relevant_type_name}Payload"

    attrs = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "nodeId": ID,
        relevant_attr_name: Field(
            table_factory(sqla_model),
            resolve=default_resolver,
            # BUG FIX: previously said "created by"; this is the update payload
            description=f"The {relevant_type_name} that was updated by this mutation.",
        ),
    }

    return UpdatePayloadType(
        result_name, attrs, description=f"The output of our update {relevant_type_name} mutation", sqla_model=sqla_model
    )
# SQL executed before/after tests to reset the database to a pristine state.
# Dropping and recreating the schema removes every table, view, type, and
# function an individual test created.
SQL_DOWN = """
DROP SCHEMA public CASCADE;
CREATE SCHEMA public;
GRANT ALL ON SCHEMA public TO nebulo_user;
"""


@pytest.fixture(scope="function", autouse=True)
def clear_caches():
    # Reflection results are memoized; clear between tests so each test's
    # freshly-created schema is re-reflected rather than served from cache.
    reflect_all_constraint_comments.cache_clear()


@pytest.fixture(scope="session")
def connection_str():
    """ SQLAlchemy connection string to test database """
    return "postgresql://nebulo_user:password@localhost:4442/nebulo_db"


@pytest.fixture(scope="session")
def engine(connection_str: str):
    """ SQLAlchemy engine """
    _engine = create_engine(connection_str, echo=False)
    # Make sure the schema is clean
    _engine.execute(SQL_DOWN)
    yield _engine
    # Teardown: reset the schema and release pooled connections
    _engine.execute(SQL_DOWN)
    _engine.dispose()


@pytest.fixture(scope="session")
def session_maker(engine):
    # scoped_session hands every caller the same thread-local Session
    smake = sessionmaker(bind=engine)
    _session = scoped_session(smake)
    yield _session


@pytest.fixture
def session(session_maker):
    # Function-scoped session: wipe the public schema before the test runs,
    # then roll back any uncommitted work and wipe it again afterwards.
    _session = session_maker
    _session.execute(SQL_DOWN)
    _session.commit()
    yield _session
    _session.rollback()
    _session.execute(SQL_DOWN)
    _session.commit()
    _session.close()


@pytest.fixture
def schema_builder(session, engine):
    """Return a function that accepts a sql string
    and returns graphql schema"""

    def build(sql: str):
        # Apply the test's DDL, then reflect it into a GraphQL schema
        session.execute(sql)
        session.commit()
        tables, functions = reflect_sqla_models(engine, schema="public")
        schema = sqla_models_to_graphql_schema(tables, functions)
        return schema

    yield build


@pytest.fixture
def app_builder(connection_str, session) -> Callable[[str, Optional[str], Optional[str]], Starlette]:
    """Return a function that executes *sql* and builds a Starlette app over the result."""

    def build(sql: str, jwt_identifier: Optional[str] = None, jwt_secret: Optional[str] = None) -> Starlette:
        session.execute(sql)
        session.commit()
        # Create the schema
        app = create_app(connection_str, jwt_identifier=jwt_identifier, jwt_secret=jwt_secret)
        return app

    return build


@pytest.fixture
def client_builder(
    app_builder: Callable[[str, Optional[str], Optional[str]], Starlette]
) -> Callable[[str, Optional[str], Optional[str]], TestClient]:
    # NOTE: Client must be used as a context manager for on_startup and on_shutdown to execute
    # e.g. connect to the database

    def build(sql: str, jwt_identifier: Optional[str] = None, jwt_secret: Optional[str] = None) -> TestClient:
        # Re-import table_base so the declarative base's metadata does not
        # accumulate tables across tests
        importlib.reload(table_base)
        app = app_builder(sql, jwt_identifier, jwt_secret)
        client = TestClient(app)
        return client

    return build
def sqla_models_to_graphql_schema(
    sqla_models,
    sql_functions: typing.Optional[typing.List[SQLFunction]] = None,
    jwt_identifier: typing.Optional[str] = None,
    jwt_secret: typing.Optional[str] = None,
) -> Schema:
    """Creates a GraphQL Schema from SQLA Models

    **Parameters**

    * **sqla_models**: _List[Type[SQLAModel]]_ = List of SQLAlchemy models to include in the GraphQL schema
    * **jwt_identifier**: _str_ = qualified path of SQL composite type to use encode as a JWT e.g. 'public.jwt'
    * **jwt_secret**: _str_ = Secret key used to encrypt JWT contents
    * **sql_functions** = **NOT PUBLIC API**
    """

    if sql_functions is None:
        sql_functions = []

    query_fields = {}
    mutation_fields = {}

    # Tables: each reflected model contributes up to five entrypoints, any of
    # which may be suppressed by comment directives (Config.exclude_*)
    for sqla_model in sqla_models:

        if not Config.exclude_read_one(sqla_model):
            # e.g. account(nodeId: NodeID)
            single_name = snake_to_camel(get_table_name(sqla_model), upper=False)
            query_fields[single_name] = table_field_factory(sqla_model, resolver)

        if not Config.exclude_read_all(sqla_model):
            # e.g. allAccounts(first: Int, last: Int ....)
            connection_name = "all" + snake_to_camel(to_plural(get_table_name(sqla_model)), upper=True)
            query_fields[connection_name] = connection_field_factory(sqla_model, resolver)

        if not Config.exclude_create(sqla_model):
            # e.g. createAccount(input: CreateAccountInput)
            mutation_fields.update(create_entrypoint_factory(sqla_model, resolver=resolver))

        if not Config.exclude_update(sqla_model):
            # e.g. updateAccount(input: UpdateAccountInput)
            mutation_fields.update(update_entrypoint_factory(sqla_model, resolver=resolver))

        if not Config.exclude_delete(sqla_model):
            # e.g. deleteAccount(input: DeleteAccountInput)
            mutation_fields.update(delete_entrypoint_factory(sqla_model, resolver=resolver))

    # Functions
    for sql_function in sql_functions:
        # A function returning the designated JWT composite type becomes a
        # token-issuing mutation (e.g. authenticate); it needs the secret to sign
        if is_jwt_function(sql_function, jwt_identifier):
            mutation_fields.update(
                mutable_function_entrypoint_factory(sql_function=sql_function, resolver=resolver, jwt_secret=jwt_secret)
            )
        else:

            # Immutable functions are queries
            if sql_function.is_immutable:
                query_fields.update(immutable_function_entrypoint_factory(sql_function=sql_function, resolver=resolver))

            # Mutable functions are mutations
            else:
                mutation_fields.update(
                    mutable_function_entrypoint_factory(sql_function=sql_function, resolver=resolver)
                )

    # GraphQL requires a root type, when present, to declare at least one
    # field; drop Query/Mutation entirely if it ended up empty
    schema_kwargs = {
        "query": ObjectType(name="Query", fields=query_fields),
        "mutation": ObjectType(name="Mutation", fields=mutation_fields),
    }
    return Schema(**{k: v for k, v in schema_kwargs.items() if v.fields})
create table account (
    id bigserial primary key,
    username text not null unique,
    password_hash text not null,
    created_at timestamp without time zone default (now() at time zone 'utc')
);
-- Do not expose the "password_hash" column
comment on column account.password_hash is E'@exclude create, update, delete, read';
-- Hide createAccount mutation, so we can provide our own to handle "password_hash"
comment on table account is E'@exclude create';

-- Allow the anonymous user to create and read an account
grant insert on table public.account to anon_api_user;
grant select on table public.account to anon_api_user;


create table blog_post(
    id bigserial primary key,
    account_id bigint not null references account(id),
    title text not null,
    body text not null,
    created_at timestamp without time zone default (now() at time zone 'utc')
);

-- Enable extension for safe hashing and checking of passwords
create extension pgcrypto;

-- create our own createAccount mutation to create accounts and hash the password
create or replace function create_account(
    username text,
    password text
) returns account
as $$
declare
    acct account;
begin
    with new_account as (
        insert into account(username, password_hash) values
        -- Hash the password (bcrypt, cost factor 8)
        (username, crypt(password, gen_salt('bf', 8)))
        returning *
    )
    select * into acct from new_account;

    return acct;
end;
$$ language plpgsql;


/*****************
JWT Authentication
******************/
-- Composite type describing the JWT payload issued by "authenticate"
create type jwt as (
    account_id bigint,
    username text,
    /* special field to sets the postgres
    role for the current transaction to
    anon_api_user or api_user */
    role text,
    exp integer
);
/*************
Access Control
**************/
-- Opt in to row level security for the account and blog_post tables
alter table account enable row level security;
alter table blog_post enable row level security;

-- An anonymous user may create an account
create policy rls_account_anon_insert on account for insert to anon_api_user with check (true);

-- Accounts can be seen by anyone
create policy rls_account_select on account for select using (true);

-- Accounts may be edited by their owner
-- ("jwt.claims.account_id" is a per-transaction setting populated from the verified JWT)
create policy rls_account_all on account using (id = nullif(current_setting('jwt.claims.account_id', true), '')::bigint);

-- Blog posts can be seen by anyone
create policy rls_blog_post_select on blog_post for select using (true);

-- Blog posts are editable by their account owner
create policy rls_blog_post_mod on blog_post using (account_id = nullif(current_setting('jwt.claims.account_id', true), '')::bigint);
@lru_cache()
def get_table_name(entity: Union[Table, TableProtocol]) -> str:
    """Return the name of the table backing *entity*."""
    table = to_table(entity)
    return str(table.name)


@lru_cache()
def get_relationships(sqla_model: TableProtocol) -> List[RelationshipProperty]:
    """Return the ORM relationships the model has with other tables."""
    mapper = sql_inspect(sqla_model)
    return list(mapper.relationships)


@lru_cache()
def get_primary_key_columns(sqla_model: TableProtocol) -> List[Column]:
    """Return the model's primary key columns."""
    return list(sqla_model.__table__.primary_key.columns)


@lru_cache()
def get_columns(sqla_model: TableProtocol) -> List[Column]:
    """Return all columns on the model's table."""
    return list(sqla_model.__table__.columns)


@lru_cache()
def get_constraints(entity: Union[TableProtocol, Table]) -> List[Constraint]:
    """Return all constraints declared on the table."""
    return list(to_table(entity).constraints)


@lru_cache()
def is_nullable(relationship: RelationshipProperty) -> bool:
    """True if any column pair on either side of the relationship is nullable."""
    return any(
        local_col.nullable or remote_col.nullable
        for local_col, remote_col in relationship.local_remote_pairs
    )


@lru_cache()
def get_comment(entity: Union[Table, TableProtocol, Column, Constraint]) -> str:
    """Return the SQL comment on *entity*, or "" when absent."""
    # ORM models must be unwrapped to their Table before reading .comment
    if isinstance(entity, TableProtocol):
        return to_table(entity).comment or ""
    if isinstance(entity, (Table, Column)):
        return entity.comment or ""
    if isinstance(entity, Constraint):
        # Constraint comments are stashed in the .info dict during reflection
        if hasattr(entity, "info"):
            return getattr(entity, "info").get("comment") or ""
        return ""
    raise ValueError("invalid entity passed to get_comment")


@lru_cache()
def get_foreign_key_constraints(entity: Union[TableProtocol, Table]) -> List[ForeignKeyConstraint]:
    """Return the table's foreign key constraints."""
    return [
        constraint
        for constraint in get_constraints(entity)
        if isinstance(constraint, ForeignKeyConstraint)
    ]


@lru_cache()
def get_foreign_key_constraint_from_relationship(relationship: RelationshipProperty) -> Optional[ForeignKeyConstraint]:
    """Get the ForeignKeyConstraint that backs the input RelationshipProperty

    Note: the resolution method checks that the columns associated with the relationship
    match the columns associated with the foreign key constraint. If two identical
    foreign keys exist (they never should), the behavior is undefined
    """
    local_cols = list(relationship.local_columns)
    remote_cols = list(relationship.remote_side)

    # Every column involved in the relationship, both sides
    relationship_cols = {*local_cols, *remote_cols}

    local_table: Table = local_cols[0].table
    remote_table: Table = remote_cols[0].table

    # Candidate constraints live on either side of the relationship
    candidates = [
        *get_foreign_key_constraints(local_table),
        *get_foreign_key_constraints(remote_table),
    ]

    for candidate in candidates:
        # Constrained columns plus the referenced columns each element points at
        involved_cols = {*list(candidate.columns), *(element.column for element in candidate.elements)}
        if involved_cols == relationship_cols:
            return candidate
    return None
def reflect_composites(engine, schema, type_map) -> Dict[Tuple[str, str], TypeEngine]:
    """Reflect PostgreSQL composite types in *schema* into SQLAlchemy composites.

    Returns a dict keyed by ``(schema_name, composite_name)``. Column SQL type
    names are resolved through *type_map*; unknown names fall back to NULLTYPE.
    """

    sql = sql_text(
        """
    WITH types AS (
        SELECT
            n.nspname,
            pg_catalog.format_type ( t.oid, NULL ) AS obj_name,
            CASE
                WHEN t.typrelid != 0 THEN CAST ( 'tuple' AS pg_catalog.text )
                WHEN t.typlen < 0 THEN CAST ( 'var' AS pg_catalog.text )
                ELSE CAST ( t.typlen AS pg_catalog.text )
            END AS obj_type,
            coalesce ( pg_catalog.obj_description ( t.oid, 'pg_type' ), '' ) AS description
        FROM
            pg_catalog.pg_type t
            JOIN pg_catalog.pg_namespace n
                ON n.oid = t.typnamespace
        WHERE ( t.typrelid = 0
                OR ( SELECT c.relkind = 'c'
                        FROM pg_catalog.pg_class c
                        WHERE c.oid = t.typrelid ) )
            AND NOT EXISTS (
                    SELECT 1
                    FROM pg_catalog.pg_type el
                    WHERE el.oid = t.typelem
                        AND el.typarray = t.oid )
            AND n.nspname <> 'pg_catalog'
            AND n.nspname <> 'information_schema'
            AND n.nspname !~ '^pg_toast'
    ),
    cols AS (
        SELECT n.nspname::text AS schema_name,
            /*
            When the schema is public, pg_catalog.format_type does not
            include a schema prefix, but when its any other schema, it does.
            This function removes the schema prefix if it exists for
            consistency
            */
            (
                string_to_array(
                    pg_catalog.format_type ( t.oid, NULL ),
                    '.'
                )
            )[array_upper(
                string_to_array(
                    pg_catalog.format_type ( t.oid, NULL ),
                    '.'
                ),
                1)
            ] as obj_name,
            a.attname::text AS column_name,
            pg_catalog.format_type ( a.atttypid, a.atttypmod ) AS data_type,
            a.attnotnull AS is_required,
            a.attnum AS ordinal_position,
            pg_catalog.col_description ( a.attrelid, a.attnum ) AS description
        FROM pg_catalog.pg_attribute a
            JOIN pg_catalog.pg_type t
                ON a.attrelid = t.typrelid
            JOIN pg_catalog.pg_namespace n
                ON ( n.oid = t.typnamespace )
            JOIN types
                ON ( types.nspname = n.nspname
                    AND types.obj_name = pg_catalog.format_type ( t.oid, NULL ) )
        WHERE a.attnum > 0
            AND NOT a.attisdropped
    )
    SELECT
        cols.schema_name,
        cols.obj_name,
        cols.column_name,
        cols.data_type,
        cols.ordinal_position,
        cols.is_required,
        coalesce ( cols.description, '' ) AS description
    FROM
        cols
    WHERE
        cols.schema_name = :schema
    ORDER BY
        cols.schema_name,
        cols.obj_name,
        cols.ordinal_position
    """
    )
    rows = engine.execute(sql, schema=schema).fetchall()

    composites = {}

    # Rows are ordered by (schema, composite, ordinal), so grouping by the
    # qualified composite name yields each composite's columns in order
    for _, composite_rows in flu(rows).group_by(lambda x: x[0] + "." + x[1]):
        # attributed for python class
        # NOTE(review): attrs is populated but never read below — looks vestigial
        attrs = []
        # columns for sqla composite
        columns = []
        for (schema_name, composite_name, column_name, data_type, _, is_required, desc) in composite_rows:  # ordinal
            attrs.append(column_name)
            # Unknown SQL type names degrade to NULLTYPE rather than failing
            column_type = type_map.get(data_type, sqltypes.NULLTYPE)
            nullable = not is_required
            column = Column(name=column_name, key=column_name, type_=column_type, nullable=nullable, comment=desc)
            columns.append(column)

        py_composite_name = snake_to_camel(composite_name, upper=True)
        # sqla_composite = CompositeType(py_composite_name, columns)
        sqla_composite = composite_type_factory(py_composite_name, columns, composite_name, schema_name)
        composites[(schema_name, composite_name)] = sqla_composite
    return composites
def reflect_virtual_primary_key_constraint(comment: str) -> PrimaryKeyConstraint:
    """
    Reflects virtual primary key constraints from view comments

    Format:
        @primary_key (id, key2)

    Raises:
        ValueError: when the comment contains no @primary_key directive.
        Views have no real primary key, so the directive is mandatory.
    """
    for row in comment.split("\n"):
        if not row.startswith("@primary_key"):
            continue

        # Remove all spaces
        # Ex: @primary_key(id)
        row = row.replace(" ", "")

        # Ex: @primary_key(id,key2) -> ["id", "key2"]
        col_names = row.split("(")[1].strip().strip(")").split(",")
        return PrimaryKeyConstraint(*col_names)

    # BUG FIX: raise ValueError instead of bare Exception, for consistency
    # with reflect_virtual_foreign_key_constraints below
    raise ValueError("Views must have a @primary_key comment")


def reflect_virtual_foreign_key_constraints(comment: str) -> List[ForeignKeyConstraint]:
    """
    Reflects virtual foreign key constraints from view comments

    Formats:
        @foreign_key (variant_id, key2) references public.variant (id, key2)
        @foreign_key (variant_id, key2) references public.variant (id, key2) LocalName RemoteName

    Raises:
        ValueError: when an @foreign_key directive matches neither format.
    """
    # Try the more specific (named) template first; fall back to the plain form
    templates = [
        "@foreign_key ({local_col_csv}) references {schema_name}.{remote_table_name} ({remote_col_csv}) {local_name_for_remote} {remote_name_for_local}",
        "@foreign_key ({local_col_csv}) references {schema_name}.{remote_table_name} ({remote_col_csv})",
    ]

    foreign_keys = []
    for row in comment.split("\n"):
        if not row.startswith("@foreign_key"):
            continue

        # Remove right spaces
        row = row.rstrip()

        for template in templates:
            match = parse(template, row)
            if match:
                break
        else:
            raise ValueError("invalid comment directive")

        # Ex: ['variant_id']
        local_col_names = [x.strip() for x in match.named["local_col_csv"].split(",")]
        remote_col_names = [x.strip() for x in match.named["remote_col_csv"].split(",")]

        remote_table = match.named["schema_name"].strip() + "." + match.named["remote_table_name"].strip()

        remote_qualified_col_names = [f"{remote_table}.{remote_col_name}" for remote_col_name in remote_col_names]

        local_name = match.named.get("local_name_for_remote")
        remote_name = match.named.get("remote_name_for_local")

        # Stash optional relationship names in .info as an @name directive so
        # downstream relationship naming can pick them up
        foreign_key = ForeignKeyConstraint(
            local_col_names,
            remote_qualified_col_names,
            info={"comment": f"@name {local_name} {remote_name}" if local_name else ""},
        )
        foreign_keys.append(foreign_key)
    return foreign_keys
4 | 5 | ### Installation 6 | 7 | Requires: Python 3.7+ 8 | 9 | ```shell 10 | pip install nebulo 11 | ``` 12 | 13 | ### Database Setup 14 | 15 | If you don't have PostgreSQL installed locally, the following docker command creates an instance with the connection string used for the remainder of the quickstart guide. 16 | 17 | ```shell 18 | docker run --rm --name nebulo_demo -p 4443:5432 -d -e POSTGRES_DB=nebulo_db -e POSTGRES_PASSWORD=password -e POSTGRES_USER=nebulo_user -d postgres 19 | ``` 20 | 21 | Next, we need to define our database schema. We need a table for accounts, and another for blog posts. All blog posts must be associated with an author in the accounts table. Additionally, we don't want to allow users to update or delete their blog post. 22 | 23 | ```sql 24 | -- blog_schema.sql 25 | 26 | CREATE TABLE public.account ( 27 | id SERIAL PRIMARY KEY, 28 | name TEXT NOT NULL, 29 | created_at TIMESTAMP NOT NULL DEFAULT now() 30 | ); 31 | 32 | CREATE TABLE public.blog_post ( 33 | id SERIAL PRIMARY KEY, 34 | title TEXT NOT NULL, 35 | body TEXT, 36 | author_id INT REFERENCES account(id), 37 | created_at TIMESTAMP NOT NULL DEFAULT now() 38 | ); 39 | 40 | INSERT INTO account (id, name) VALUES 41 | (1, 'Oliver'), 42 | (2, 'Buddy'); 43 | 44 | INSERT INTO blog_post (id, author_id, title) VALUES 45 | (1, 1, 'First Post'), 46 | (2, 1, 'Sanitize all the things!'), 47 | (3, 2, 'To Bite or not to Bite'); 48 | 49 | 50 | COMMENT ON TABLE blog_post IS E'@exclude update, delete'; 51 | ``` 52 | 53 | Note the comment to exclude updates and deletes from blog posts. The same comment format can also be applied to exclude columns from reflection e.g. password hashes. 54 | 55 | ### GraphQL Schema 56 | 57 | Now we're ready to reflect the database schema into a GraphQL schema. The GraphQL schema document describes our API's types, fields, and relationships between entities. If you're new to GraphQL check out their [documentation](https://graphql.org/learn/) for more information. 
where the connection string provided by `-c` is in the format `postgresql://<user>:<password>@<host>:<port>/<database>`
110 | } 111 | 112 | scalar DateTime 113 | 114 | type BlogPostConnection { 115 | edges: [BlogPostEdge!]! 116 | pageInfo: PageInfo! 117 | totalCount: Int! 118 | } 119 | 120 | type BlogPostEdge { 121 | cursor: Cursor 122 | node: BlogPost 123 | } 124 | 125 | scalar Cursor 126 | 127 | 128 | 129 | ``` 130 | 131 | Notice that our API detected the foreign key relationship between `Account` and `BlogPost` and created `blogPostsByIdToAuthorId` and `accountByAuthorIdToId` on their base types respectively. 132 | 133 | 134 | ### Query the API 135 | 136 | To start the API server, execute `neb run` passing in a connection to the database. 137 | 138 | ```shell 139 | neb run -c postgresql://nebulo_user:password@localhost:4443/nebulo_db 140 | ``` 141 | 142 | In addition to handling GraphQL requests, nebulo also serves the [GraphiQL explorer](https://github.com/graphql/graphiql) locally at [http://localhost:5034/graphiql](http://localhost:5034/graphiql). 143 | 144 | 145 | ![graphiql image](images/graphiql.png) 146 | 147 | Enter your query in GraphiQL and click the arrow icon to execute it. 148 | 149 | You're all done! 
150 | -------------------------------------------------------------------------------- /src/nebulo/gql/convert/column.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=invalid-name,unsubscriptable-object 2 | from __future__ import annotations 3 | 4 | import typing 5 | from functools import lru_cache 6 | 7 | from cachetools import LRUCache, cached 8 | from nebulo.config import Config 9 | from nebulo.gql.alias import ( 10 | Boolean, 11 | CompositeType, 12 | EnumType, 13 | Field, 14 | Float, 15 | InputField, 16 | InputObjectType, 17 | Int, 18 | List, 19 | NonNull, 20 | ScalarType, 21 | String, 22 | ) 23 | from nebulo.gql.resolve.resolvers.default import default_resolver 24 | from nebulo.sql.composite import CompositeType as SQLACompositeType 25 | from nebulo.sql.table_base import TableProtocol 26 | from nebulo.text_utils import snake_to_camel 27 | from sqlalchemy import Column, types 28 | from sqlalchemy.dialects import postgresql 29 | from sqlalchemy.sql import sqltypes 30 | from sqlalchemy.sql.type_api import TypeEngine 31 | 32 | UnknownType = ScalarType(name="UnknownString", serialize=str) 33 | DateTimeType = ScalarType(name="DateTime", serialize=str) 34 | DateType = ScalarType(name="Date", serialize=str) 35 | TimeType = ScalarType(name="Time", serialize=str) 36 | UUIDType = ScalarType(name="UUID", serialize=str) 37 | INETType = ScalarType(name="INET", serialize=str) 38 | CIDRType = ScalarType(name="CIDR", serialize=str) 39 | 40 | 41 | SQLA_TO_GQL = { 42 | types.Boolean: Boolean, 43 | sqltypes.BOOLEAN: Boolean, 44 | # Integer 45 | types.Integer: Int, 46 | types.INTEGER: Int, 47 | types.BIGINT: Int, 48 | # Float 49 | types.Float: Float, 50 | types.Numeric: Float, 51 | # required b/c json uses 32 bit integers 52 | types.BigInteger: Float, 53 | # Text 54 | types.Text: String, 55 | types.String: String, 56 | types.Unicode: String, 57 | types.UnicodeText: String, 58 | # Date 59 | types.Date: DateType, 60 | 
@lru_cache()
def convert_type(sqla_type: TypeEngine):
    """Map a SQLAlchemy type *instance* to its GraphQL output type.

    Composites, enums, arrays, and table types are built by dedicated
    factories; any unrecognized type falls back to String.
    """
    type_class = type(sqla_type)

    if issubclass(type_class, SQLACompositeType):
        return composite_factory(type_class)

    if issubclass(type_class, postgresql.base.ENUM):
        return enum_factory(sqla_type)  # type: ignore

    if issubclass(type_class, postgresql.ARRAY):
        return array_factory(sqla_type)  # type: ignore

    if isinstance(sqla_type, TableProtocol):
        # Local import avoids a circular dependency with table.py
        from .table import table_factory

        return table_factory(sqla_type)

    return SQLA_TO_GQL.get(type_class, String)


@lru_cache()
def convert_column(column: Column) -> Field:
    """Converts a sqlalchemy column into a graphql field"""
    gql_type = convert_type(column.type)
    notnull = not column.nullable
    # Non-nullable SQL columns become NonNull GraphQL fields
    return_type = NonNull(gql_type) if notnull else gql_type
    return Field(return_type, resolve=default_resolver)


# Cached by schema-qualified enum name so repeated reflections of the same
# enum share one GraphQL type.
# FIX: the key must be an f-string. The previous plain "{x.schema}-{x.name}"
# literal was a constant string, so *every* enum shared a single cache entry
# and all enums after the first resolved to the wrong GraphQL EnumType.
@cached(LRUCache(maxsize=9999), key=lambda x: f"{x.schema}-{x.name}")  # type: ignore
def enum_factory(sqla_enum: typing.Type[postgresql.base.ENUM]) -> EnumType:
    """Build a GraphQL enum mirroring the postgres enum's values"""
    name = Config.enum_name_mapper(sqla_enum)
    return EnumType(name=name, values={val: val for val in sqla_enum.enums})


@lru_cache()
def array_factory(sqla_array: typing.Type[postgresql.ARRAY]) -> List:
    """Wrap the array's item type in a GraphQL List"""
    return List(convert_type(sqla_array.item_type))


@lru_cache()
def composite_factory(sqla_composite: SQLACompositeType) -> CompositeType:
    """Build a GraphQL object type for a postgres composite type"""
    name = snake_to_camel(sqla_composite.name, upper=True)
    fields = {}

    for column in sqla_composite.columns:
        column_key = str(column.key)
        gql_key = column_key
        gql_type = convert_column(column)
        fields[gql_key] = gql_type

    return_type = CompositeType(name, fields)
    # Stash the source composite so downstream code can recover it
    return_type.sqla_composite = sqla_composite
    return return_type


@lru_cache()
def convert_input_type(sqla_type: typing.Type[TypeEngine[typing.Any]]):
    """Map a SQLAlchemy type *class* to its GraphQL input type"""
    if issubclass(sqla_type, SQLACompositeType):
        gql_type = composite_input_factory(sqla_type)
    else:
        gql_type = SQLA_TO_GQL.get(sqla_type, String)
    return gql_type


@lru_cache()
def convert_column_to_input(column: Column) -> InputField:
    """Converts a sqlalchemy column into a graphql input field"""
    sqla_type = type(column.type)
    gql_type = convert_input_type(sqla_type)
    return InputField(gql_type)


@lru_cache()
def composite_input_factory(sqla_composite: SQLACompositeType) -> InputObjectType:
    """Build a GraphQL input object type for a postgres composite type"""
    name = snake_to_camel(sqla_composite.name, upper=True) + "Input"
    fields = {}

    for column in sqla_composite.columns:
        column_key = str(column.key)
        gql_key = column_key
        gql_field = convert_column_to_input(column)
        fields[gql_key] = gql_field

    return_type = InputObjectType(name, fields)
    return_type.sqla_composite = sqla_composite
    return return_type
class SQLFunction:
    """A PostgreSQL Function

    INTERNAL USE ONLY
    """

    def __init__(
        self,
        schema: str,
        name: str,
        arg_names: List[Optional[str]],
        arg_pg_types: List[str],
        arg_sqla_types: List[Type[TypeEngine[Any]]],
        return_sqla_type: Type[TypeEngine[Any]],
        return_pg_type_schema: str,
        return_pg_type: str,
        is_immutable: bool,
    ):
        # The three argument lists describe the same parameters positionally,
        # so all must have the same length.
        # FIX: the previous chained check `a != b != c` only compared adjacent
        # pairs, so a length mismatch in arg_pg_types alone (with arg_names and
        # arg_sqla_types equal) slipped through undetected.
        if not (len(arg_names) == len(arg_sqla_types) == len(arg_pg_types)):
            raise SQLParseError("SQLFunction requires same number of arg_names and sqla_types")
        self.schema = schema
        self.name = name
        self.arg_names = arg_names
        self.arg_pg_types = arg_pg_types
        self.arg_sqla_types = arg_sqla_types
        self.return_sqla_type = return_sqla_type
        self.return_pg_type_schema = return_pg_type_schema
        self.return_pg_type = return_pg_type
        self.is_immutable = is_immutable

    def to_executable(self, args: List):
        """Return a SQLAlchemy expression calling this function with *args*
        bound and cast to the reflected argument types.

        Raises:
            SQLParseError: when the number of *args* does not match the
                function's signature.
        """
        if len(args) != len(self.arg_pg_types):
            raise SQLParseError(f"Invalid number of parameters for SQLFunction {self.schema}.{self.name}")

        # SQLAlchemy dynamic representation of the function
        sqla_func = getattr(getattr(func, self.schema), self.name)
        # Bind and typecast user parameters
        call_sig = [cast(arg_value, arg_sqla_type) for arg_value, arg_sqla_type in zip(args, self.arg_sqla_types)]
        return sqla_func(*call_sig)


def reflect_functions(engine: Engine, schema: str, type_map: Dict) -> List[SQLFunction]:
    """Get a list of functions available in the database.

    Extension-owned functions and the system schemas are skipped; *schema*
    is matched with SQL LIKE, so wildcard patterns are accepted.
    """

    # TODO: Support default arguments
    # I haven't been able to find a way to get an array of default args
    # but you can get a function signature including and not including
    # see proargdefaults, pronargdefaults,
    # pg_get_function_identity_arguments, get_function_arguments

    sql = sql_text(
        """
    with extension_functions as (
        select
            objid as extension_function_oid
        from
            pg_depend
        where
            -- depends on an extension
            deptype='e'
            -- is a proc/function
            and classid = 'pg_proc'::regclass
    )

    select
        n.nspname as function_schema,
        p.proname as function_name,
        coalesce(proargnames, '{}'::text[]) arg_names,
        coalesce(
            (select array_agg((select typnamespace::regnamespace::text from pg_type where oid=type_oid)) from unnest(proargtypes) x(type_oid)),
            '{}'::text[]
        ) as arg_types_schema,
        coalesce(
            (select array_agg(type_oid::regtype::text) from unnest(proargtypes) x(type_oid)),
            '{}'::text[]
        ) arg_types,
        t.typnamespace::regnamespace::text as return_type_schema,
        t.typname as return_type,
        p.provolatile = 'i' is_immutable
    from
        pg_proc p
        left join pg_namespace n on p.pronamespace = n.oid
        left join pg_language l on p.prolang = l.oid
        left join pg_type t on t.oid = p.prorettype
        left join extension_functions ef on p.oid = ef.extension_function_oid
    where
        n.nspname not in ('pg_catalog', 'information_schema')
        and ef.extension_function_oid is null
        and n.nspname like :schema
    """
    )
    rows = engine.execute(sql, schema=schema).fetchall()

    functions: List[SQLFunction] = []

    for (
        func_schema,
        func_name,
        arg_names,
        arg_type_schemas,
        pg_arg_types,
        pg_return_type_schema,
        pg_return_type_name,
        is_immutable,
    ) in rows:
        # proargnames may be shorter than proargtypes; pad missing names with None
        pg_arg_names = [arg_name for arg_name, _ in zip_longest(arg_names, pg_arg_types, fillvalue=None)]
        # Unknown pg types degrade to NULLTYPE rather than failing reflection.
        # (Dropped the dead `or []` that followed zip(): zip objects are always
        # truthy, so the fallback could never apply.)
        sqla_arg_types = [
            type_map.get(pg_type_name, sqltypes.NULLTYPE)
            for pg_type_schema, pg_type_name in zip(arg_type_schemas, pg_arg_types)
        ]
        sqla_return_type = type_map.get(pg_return_type_name, sqltypes.NULLTYPE)

        function = SQLFunction(
            schema=func_schema,
            name=func_name,
            arg_names=pg_arg_names,
            arg_pg_types=[str(x) for x in pg_arg_types],  # noop to silence mypy
            arg_sqla_types=sqla_arg_types,
            return_sqla_type=sqla_return_type,
            return_pg_type_schema=pg_return_type_schema,
            return_pg_type=pg_return_type_name,
            is_immutable=bool(is_immutable),
        )
        functions.append(function)

    return functions
arg_pg_types=[str(x) for x in pg_arg_types], # noop to silence mypy 133 | arg_sqla_types=sqla_arg_types, 134 | return_sqla_type=sqla_return_type, 135 | return_pg_type_schema=pg_return_type_schema, 136 | return_pg_type=pg_return_type_name, 137 | is_immutable=bool(is_immutable), 138 | ) 139 | functions.append(function) 140 | 141 | return functions 142 | -------------------------------------------------------------------------------- /src/nebulo/gql/resolve/resolvers/asynchronous.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import typing 4 | 5 | from flupy import flu 6 | from nebulo.config import Config 7 | from nebulo.gql.alias import FunctionPayloadType, MutationPayloadType, ObjectType, ResolveInfo, ScalarType 8 | from nebulo.gql.parse_info import parse_resolve_info 9 | from nebulo.gql.relay.node_interface import NodeIdStructure, to_node_id_sql 10 | from nebulo.gql.resolve.resolvers.claims import build_claims 11 | from nebulo.gql.resolve.transpile.mutation_builder import build_mutation 12 | from nebulo.gql.resolve.transpile.query_builder import sql_builder, sql_finalize 13 | from nebulo.sql.table_base import TableProtocol 14 | from sqlalchemy import literal_column, select 15 | 16 | 17 | async def async_resolver(_, info: ResolveInfo, **kwargs) -> typing.Any: 18 | """Awaitable GraphQL Entrypoint resolver 19 | 20 | Expects: 21 | info.context['engine'] to contain an sqlalchemy.ext.asyncio.AsyncEngine 22 | """ 23 | context = info.context 24 | engine = context["engine"] 25 | default_role = context["default_role"] 26 | jwt_claims = context["jwt_claims"] 27 | 28 | tree = parse_resolve_info(info) 29 | 30 | async with engine.begin() as trans: 31 | # Set claims for transaction 32 | if jwt_claims or default_role: 33 | claims_stmt = build_claims(jwt_claims, default_role) 34 | await trans.execute(claims_stmt) 35 | 36 | result: typing.Dict[str, typing.Any] 37 | 38 | if 
isinstance(tree.return_type, FunctionPayloadType): 39 | sql_function = tree.return_type.sql_function 40 | function_args = [val for key, val in tree.args["input"].items() if key != "clientMutationId"] 41 | func_call = sql_function.to_executable(function_args) 42 | 43 | # Function returning table row 44 | if isinstance(sql_function.return_sqla_type, TableProtocol): 45 | # Unpack the table row to columns 46 | return_sqla_model = sql_function.return_sqla_type 47 | core_table = return_sqla_model.__table__ 48 | func_alias = func_call.alias("named_alias") 49 | stmt = select([literal_column(c.name).label(c.name) for c in core_table.c]).select_from(func_alias) # type: ignore 50 | stmt_alias = stmt.alias() 51 | node_id_stmt = select([to_node_id_sql(return_sqla_model, stmt_alias).label("nodeId")]).select_from(stmt_alias) # type: ignore 52 | ((row,),) = await trans.execute(node_id_stmt) 53 | node_id = NodeIdStructure.from_dict(row) 54 | 55 | # Add nodeId to AST and query 56 | query_tree = next(iter([x for x in tree.fields if x.name == "result"]), None) 57 | if query_tree is not None: 58 | query_tree.args["nodeId"] = node_id 59 | base_query = sql_builder(query_tree) 60 | query = sql_finalize(query_tree.alias, base_query) 61 | ((stmt_result,),) = await trans.execute(query) 62 | else: 63 | stmt_result = {} 64 | else: 65 | stmt = select([func_call.label("result")]) 66 | (stmt_result,) = await trans.execute(stmt) 67 | 68 | maybe_mutation_id = tree.args["input"].get("clientMutationId") 69 | mutation_id_alias = next( 70 | iter([x.alias for x in tree.fields if x.name == "clientMutationId"]), 71 | "clientMutationId", 72 | ) 73 | result = {tree.alias: {**stmt_result, **{mutation_id_alias: maybe_mutation_id}}} 74 | 75 | elif isinstance(tree.return_type, MutationPayloadType): 76 | stmt = build_mutation(tree) 77 | ((row,),) = await trans.execute(stmt) 78 | node_id = NodeIdStructure.from_dict(row) 79 | 80 | maybe_mutation_id = tree.args["input"].get("clientMutationId") 81 | 
mutation_id_alias = next( 82 | iter([x.alias for x in tree.fields if x.name == "clientMutationId"]), 83 | "clientMutationId", 84 | ) 85 | node_id_alias = next(iter([x.alias for x in tree.fields if x.name == "nodeId"]), "nodeId") 86 | output_row_name: str = Config.table_name_mapper(tree.return_type.sqla_model) 87 | query_tree = next(iter([x for x in tree.fields if x.name == output_row_name]), None) 88 | sql_result = {} 89 | if query_tree: 90 | # Set the nodeid of the newly created record as an arg 91 | query_tree.args["nodeId"] = node_id 92 | base_query = sql_builder(query_tree) 93 | query = sql_finalize(query_tree.alias, base_query) 94 | ((sql_result,),) = await trans.execute(query) 95 | result = { 96 | tree.alias: {**sql_result, mutation_id_alias: maybe_mutation_id}, 97 | mutation_id_alias: maybe_mutation_id, 98 | node_id_alias: node_id, 99 | } 100 | 101 | elif isinstance(tree.return_type, (ObjectType, ScalarType)): 102 | base_query = sql_builder(tree) 103 | query = sql_finalize(tree.name, base_query) 104 | ((query_json_result,),) = await trans.execute(query) 105 | 106 | if isinstance(tree.return_type, ScalarType): 107 | # If its a scalar, unwrap the top level name 108 | result = flu(query_json_result.values()).first(None) 109 | else: 110 | result = query_json_result 111 | 112 | else: 113 | raise Exception("sql builder could not handle return type") 114 | 115 | # Stash result on context to enable dumb resolvers to not fail 116 | context["result"] = result 117 | return result 118 | -------------------------------------------------------------------------------- /src/nebulo/gql/parse_info.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import typing 4 | from dataclasses import dataclass 5 | 6 | from graphql.execution.execute import get_field_def 7 | from graphql.execution.values import get_argument_values 8 | from graphql.language import FieldNode, FragmentDefinitionNode, 
__all__ = ["parse_resolve_info"]


def field_to_type(field):
    """Recursively unwraps nested Field, List, and NonNull
    qualifiers to a concrete GraphQL Type"""
    if isinstance(field, Field):
        return field_to_type(field.type)
    if isinstance(field, List):
        return field_to_type(field.of_type)
    if isinstance(field, NonNull):
        return field_to_type(field.of_type)
    return field


@dataclass(unsafe_hash=True)
class ASTNode:
    """One node of the parsed GraphQL query tree.

    Built recursively from a FieldNode: fragments (spread and inline) are
    flattened into the selection, introspection fields (``__*``) are skipped,
    and each child selection becomes a nested ASTNode in ``fields``.
    """

    # GraphQL field name as written in the schema
    name: str
    # Client-chosen alias, falling back to the field name
    alias: str
    # NOTE(review): despite the annotation, this holds the *unwrapped* GraphQL
    # type produced by field_to_type, not a FieldNode — confirm and retype
    return_type: FieldNode
    parent: typing.Optional[ASTNode]
    parent_type: typing.Optional[typing.Any]
    # Coerced argument values for this field
    args: typing.Dict[str, typing.Any]
    # Path of field names from the root ("root", ...) down to this node
    path: typing.Tuple[str, ...]
    fields: typing.List[ASTNode]

    def __init__(
        self,
        field_node: FieldNode,
        field_def: ObjectType,
        schema: Schema,
        parent: typing.Optional["ASTNode"],
        variable_values,
        parent_type,
        fragments: typing.Dict[str, FragmentDefinitionNode],
    ):
        self.name = field_node.name.value

        # A connection/edge/etc class
        # NOTE(review): the field_def parameter is immediately shadowed by this
        # schema lookup — the passed-in value is effectively unused
        field_def = get_field_def(schema, parent_type, self.name)

        # Resolve variables into concrete argument values for this field
        _args = get_argument_values(type_def=field_def, node=field_node, variable_values=variable_values)

        selection_set = field_node.selection_set
        field_type = field_to_type(field_def)

        self.alias = (field_node.alias.value if field_node.alias else None) or field_node.name.value
        self.return_type = field_type
        self.parent: typing.Optional[ASTNode] = parent
        self.parent_type = parent_type
        self.args: typing.Dict[str, typing.Any] = _args
        self.path = parent.path + (self.name,) if parent is not None else ("root",)

        def from_selection_set(selection_set) -> typing.Generator[ASTNode, None, None]:
            # Yields an ASTNode per selected field, recursing through fragments
            for selection_ast in selection_set.selections:

                # Handle fragments
                if isinstance(selection_ast, FragmentSpreadNode):
                    fragment_name = selection_ast.name.value
                    fragment = fragments[fragment_name]
                    fragment_selection_set = fragment.selection_set
                    yield from from_selection_set(fragment_selection_set)

                elif isinstance(selection_ast, InlineFragmentNode):
                    yield from from_selection_set(selection_ast.selection_set)

                else:
                    selection_name = selection_ast.name.value
                    if selection_name.startswith("__"):
                        # Reserved "introspection" field handled by framework
                        continue
                    selection_field = field_type.fields[selection_name]
                    yield ASTNode(
                        field_node=selection_ast,
                        field_def=selection_field,
                        schema=schema,
                        parent=self,
                        variable_values=variable_values,
                        parent_type=field_type,
                        fragments=fragments,
                    )

        sub_fields = list(from_selection_set(selection_set)) if selection_set else []
        self.fields = sub_fields

    def get_subfield_alias(self, path: typing.List[str]):
        """Searches subfields by name returning the alias
        for the terminating element in *path* list. If that subfield
        is not in the selection set, return the terminating elements name"""
        if len(path) < 1:
            raise Exception("Path must contain an element")

        for subfield in self.fields:
            if subfield.name == path[0]:
                if len(path) == 1:
                    return subfield.alias or path[0]
                return subfield.get_subfield_alias(path[1:])
        # Field not selected: fall back to its plain name
        return path[-1]


def parse_resolve_info(info: ResolveInfo) -> ASTNode:
    """Converts execution ResolveInfo into an ASTNode hierarchy, e.g.::

        {
            "alias": <alias>,
            "name": <name>,
            "return_type": <return_type>,
            "args": {"first": 10},
            "parent": <parent ASTNode or None>,
            "fields": [
                # Same structure again, for each field selected
                # in the query
            ],
        }
    """
    # Root info
    field_node = info.field_nodes[0]
    schema = info.schema

    fragments: typing.Dict[str, FragmentDefinitionNode] = info.fragments

    # Current field from parent
    parent_type = info.parent_type
    parent_lookup_name = field_node.name.value
    current_field = parent_type.fields[parent_lookup_name]
    parsed_info = ASTNode(
        field_node,
        current_field,
        schema,
        parent=None,
        variable_values=info.variable_values,
        parent_type=info.parent_type,
        fragments=fragments,
    )
    return parsed_info
@lru_cache()
def immutable_function_entrypoint_factory(
    sql_function: SQLFunction, resolver: typing.Callable
) -> typing.Dict[str, Field]:
    """Build the Query-level entrypoint for an immutable SQL function, e.g.::

        toUpper(someText: String!): String
    """
    # TODO(OR): need separate mapper
    function_name = Config.function_name_mapper(sql_function)

    if not sql_function.is_immutable:
        raise Exception(f"SQLFunction {sql_function.name} is not immutable, use mutable_function_entrypoint")

    # Arguments lacking a SQL name are exposed positionally as param0, param1, ...
    gql_args = {}
    for ix, (arg_name, arg_sqla_type) in enumerate(zip(sql_function.arg_names, sql_function.arg_sqla_types)):
        gql_args[arg_name or f"param{ix}"] = Argument(NonNull(convert_type(arg_sqla_type)))

    return_type = convert_type(sql_function.return_sqla_type)
    return_type.sql_function = sql_function
    entry_field = Field(return_type, args=gql_args, resolve=resolver, description="")

    return {function_name: entry_field}


def is_jwt_function(sql_function: SQLFunction, jwt_identifier: typing.Optional[str]):
    """True when the function's schema-qualified return type matches the
    configured JWT composite type identifier"""
    qualified_return_type = f"{sql_function.return_pg_type_schema}.{sql_function.return_pg_type}"
    return qualified_return_type == jwt_identifier


@lru_cache()
def mutable_function_entrypoint_factory(
    sql_function: SQLFunction, resolver: typing.Callable, jwt_secret: typing.Optional[str] = None
) -> typing.Dict[str, Field]:
    """Build the Mutation-level entrypoint for a volatile SQL function, e.g.::

        authenticate(input: AuthenticateInput!): AuthenticatePayload
    """
    # TODO(OR): need separate mapper
    function_name = Config.function_name_mapper(sql_function)

    if sql_function.is_immutable:
        raise Exception(f"SQLFunction {sql_function.name} is immutable, use immutable_function_entrypoint")

    args = {"input": NonNull(function_input_type_factory(sql_function))}

    # With a secret configured the payload serializes its result as a signed JWT
    if jwt_secret is None:
        payload = function_payload_factory(sql_function)
    else:
        payload = jwt_function_payload_factory(sql_function, jwt_secret)

    entry_field = Field(payload, args=args, resolve=resolver, description=f"Call the function {function_name}.")
    return {function_name: entry_field}


# Remainder relates to mutation functions


@lru_cache()
def function_input_type_factory(sql_function: SQLFunction) -> FunctionInputType:
    """Build the function's input type, e.g. AuthenticateInput!"""
    function_name = Config.function_type_name_mapper(sql_function)
    result_name = f"{function_name}Input"

    # clientMutationId first, then the SQL arguments in declaration order
    attrs = {"clientMutationId": String}
    for ix, (arg_name, arg_sqla_type) in enumerate(zip(sql_function.arg_names, sql_function.arg_sqla_types)):
        attrs[arg_name or f"param{ix}"] = NonNull(convert_input_type(arg_sqla_type))

    return FunctionInputType(result_name, attrs, description=f"All input for the {function_name} mutation.")


@lru_cache()
def function_payload_factory(sql_function: SQLFunction) -> FunctionPayloadType:
    """Build the function's payload type, e.g. CreateAccountPayload"""
    function_name = Config.function_type_name_mapper(sql_function)
    result_name = f"{function_name}Payload"

    # TODO(OR): handle functions with no return
    function_return_type = convert_type(sql_function.return_sqla_type)
    function_return_type.sql_function = sql_function

    result_field = Field(
        function_return_type,
        description=f"The {result_name} that was created by this mutation.",
        resolve=default_resolver,
    )
    attrs = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "result": result_field,
    }

    payload = FunctionPayloadType(result_name, attrs, description=f"The output of our create {function_name} mutation")
    payload.sql_function = sql_function
    return payload


@lru_cache()
def jwt_function_payload_factory(sql_function: SQLFunction, jwt_secret: str) -> FunctionPayloadType:
    """Build a payload type whose result is serialized as a signed JWT"""
    function_name = Config.function_type_name_mapper(sql_function)
    result_name = f"{function_name}Payload"

    def serialize_jwt(result):
        # Sign the returned row's key/value pairs into an HS256 token
        return jwt.encode({k: v for k, v in result.items()}, jwt_secret, algorithm="HS256").decode("utf-8")

    function_return_type = ScalarType("JWT", serialize=serialize_jwt)

    attrs = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "result": Field(
            function_return_type,
            description=f"The {result_name} that was created by this mutation.",
            resolve=default_resolver,
        ),
    }

    payload = FunctionPayloadType(result_name, attrs, description=f"The output of our create {function_name} mutation")
    payload.sql_function = sql_function
    return payload
# Short aliases for graphql-core types so the rest of the codebase does not
# depend on the upstream names directly.
ScalarType = GraphQLScalarType
ID = GraphQLID
InterfaceType = GraphQLInterfaceType
Int = GraphQLInt
InputField = GraphQLInputField
ResolveInfo = GraphQLResolveInfo
# FIX: EnumType was previously assigned twice; the redundant duplicate
# assignment has been removed.
EnumType = GraphQLEnumType
EnumValue = GraphQLEnumValue
Schema = GraphQLSchema
Field = GraphQLField
Float = GraphQLFloat


class HasSQLAModel:  # pylint: disable= too-few-public-methods
    """Mixin marker: GraphQL types that carry a reference to a SQLA table"""

    sqla_table = None


class HasSQLFunction:  # pylint: disable= too-few-public-methods
    """Mixin marker: GraphQL types that carry a reference to a SQL function"""

    sql_function = None


class HasSQLAComposite:  # pylint: disable= too-few-public-methods
    """Mixin marker: GraphQL types that carry a reference to a SQLA composite"""

    sqla_composite: SQLACompositeType


class ObjectType(GraphQLObjectType, HasSQLAModel):
    """GraphQLObjectType that also tracks the SQLA model it was built from
    via the extra *sqla_model* constructor argument."""

    def __init__(
        self,
        name: str,
        fields: Thunk[GraphQLFieldMap],
        interfaces: typing.Optional[Thunk[typing.Collection["GraphQLInterfaceType"]]] = None,
        is_type_of: typing.Optional[GraphQLIsTypeOfFn] = None,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
        description: typing.Optional[str] = None,
        ast_node: typing.Optional[ObjectTypeDefinitionNode] = None,
        extension_ast_nodes: typing.Optional[typing.Collection[ObjectTypeExtensionNode]] = None,
        sqla_model=None,
    ) -> None:
        super().__init__(
            name=name,
            fields=fields,
            interfaces=interfaces,
            is_type_of=is_type_of,
            extensions=extensions,
            description=description,
            ast_node=ast_node,
            extension_ast_nodes=extension_ast_nodes,
        )
        self.sqla_model = sqla_model


class ConnectionType(ObjectType):
    """Relay connection wrapper type"""


class EdgeType(ObjectType):
    """Relay edge type (cursor + node)"""


class TableType(ObjectType):
    """Object type backed by a database table"""


class CompositeType(ObjectType, HasSQLAComposite):
    """Object type backed by a postgres composite type"""


class MutationPayloadType(ObjectType):
    """Base class for all mutation payload types"""


class CreatePayloadType(MutationPayloadType):
    """Payload returned by create mutations"""


class UpdatePayloadType(MutationPayloadType):
    """Payload returned by update mutations"""


class DeletePayloadType(MutationPayloadType):
    """Payload returned by delete mutations"""


class FunctionPayloadType(MutationPayloadType, HasSQLFunction):
    """Payload returned by SQL-function mutations"""


class InputObjectType(GraphQLInputObjectType, HasSQLAModel):
    """GraphQLInputObjectType that also tracks the SQLA model it was built
    from via the extra *sqla_model* constructor argument."""

    def __init__(
        self,
        name: str,
        fields: Thunk[GraphQLInputFieldMap],
        description: typing.Optional[str] = None,
        out_type: typing.Optional[GraphQLInputFieldOutType] = None,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
        ast_node: typing.Optional[InputObjectTypeDefinitionNode] = None,
        extension_ast_nodes: typing.Optional[typing.Collection[InputObjectTypeExtensionNode]] = None,
        sqla_model=None,
    ) -> None:
        super().__init__(
            name=name,
            fields=fields,
            description=description,
            out_type=out_type,
            extensions=extensions,
            ast_node=ast_node,
            extension_ast_nodes=extension_ast_nodes,
        )
        self.sqla_model = sqla_model


class CreateInputType(InputObjectType):
    """Input type for create mutations"""


class TableInputType(InputObjectType):
    """Input type mirroring a table's columns"""


class UpdateInputType(InputObjectType):
    """Input type for update mutations"""


class DeleteInputType(InputObjectType):
    """Input type for delete mutations"""


class FunctionInputType(GraphQLInputObjectType):
    """Input type for SQL-function mutations; tracks the backing function via
    the extra *sql_function* constructor argument. (Subclasses
    GraphQLInputObjectType directly since it carries no SQLA model.)"""

    def __init__(
        self,
        name: str,
        fields: Thunk[GraphQLInputFieldMap],
        description: typing.Optional[str] = None,
        out_type: typing.Optional[GraphQLInputFieldOutType] = None,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
        ast_node: typing.Optional[InputObjectTypeDefinitionNode] = None,
        extension_ast_nodes: typing.Optional[typing.Collection[InputObjectTypeExtensionNode]] = None,
        sql_function=None,
    ) -> None:
        super().__init__(
            name=name,
            fields=fields,
            description=description,
            out_type=out_type,
            extensions=extensions,
            ast_node=ast_node,
            extension_ast_nodes=extension_ast_nodes,
        )
        self.sql_function = sql_function
api.jwt_key 29 | JWT_IDENTIFIER = ENV.get("NEBULO_JWT_IDENTIFIER") 30 | 31 | # JWT secret key 32 | JWT_SECRET = ENV.get("NEBULO_JWT_SECRET") 33 | 34 | # Default SQL role to use for unauthenticated users 35 | DEFAULT_ROLE = ENV.get("NEBULO_DEFAULT_ROLE") 36 | 37 | # Default number of entries per API page 38 | DEFAULT_PAGE_SIZE = int(ENV.get("NEBULO_DEFAULT_PAGE_SIZE", 20)) 39 | 40 | @staticmethod 41 | def function_name_mapper(sql_function: SQLFunction) -> str: 42 | """to_upper -> toUpper""" 43 | return snake_to_camel(sql_function.name, upper=False) 44 | 45 | @staticmethod 46 | def function_type_name_mapper(sql_function: SQLFunction) -> str: 47 | """to_upper -> ToUpper""" 48 | return snake_to_camel(sql_function.name, upper=True) 49 | 50 | @classmethod 51 | def table_type_name_mapper(cls, sqla_table: TableProtocol) -> str: 52 | table_comment = get_comment(sqla_table) 53 | 54 | # If an @name directive exists, use it 55 | maybe_name: Optional[str] = cls.read_name_directive(table_comment) 56 | if maybe_name: 57 | return maybe_name 58 | 59 | table_name = get_table_name(sqla_table) 60 | return snake_to_camel(table_name) 61 | 62 | @classmethod 63 | def table_name_mapper(cls, sqla_table: TableProtocol) -> str: 64 | """Return the type name with the first character lower case""" 65 | type_name = cls.table_type_name_mapper(sqla_table) 66 | return type_name[0].lower() + type_name[1:] 67 | 68 | @classmethod 69 | def column_name_mapper(cls, column: Column) -> str: 70 | """Return the type name with the first character lower case""" 71 | column_comment = get_comment(column) 72 | 73 | # If an @name directive exists, use it 74 | maybe_name: Optional[str] = cls.read_name_directive(column_comment) 75 | if maybe_name: 76 | return maybe_name 77 | 78 | column_key = column.key 79 | type_name = snake_to_camel(column_key) 80 | return type_name[0].lower() + type_name[1:] 81 | 82 | @staticmethod 83 | def read_name_directive(comment: str) -> Optional[str]: 84 | """Returns the name following a 
@name comment directive""" 85 | lines = comment.split("\n") 86 | for line in lines: 87 | line = line.strip() 88 | if line.startswith("@name"): 89 | line = line.lstrip("@name") 90 | name = line.strip() 91 | if name: 92 | return name 93 | return None 94 | 95 | @staticmethod 96 | def enum_name_mapper(enum: Type[postgresql.base.ENUM]) -> str: 97 | return snake_to_camel((enum.name or "") + "_enum", upper=True) 98 | 99 | @staticmethod 100 | def read_foreign_key_comment_directive(comment: str) -> Tuple[Optional[str], Optional[str]]: 101 | """Accepts a comment from a foreign key backing a sqlalchemy RelationshipProperty and searches for 102 | a comment directive defining what the names for each side of the relationship should be for the local 103 | and remote graphql models on either side. 104 | 105 | Template: 106 | @name {local_name_for_remote} {remote_name_for_local} 107 | 108 | Example: 109 | Given: 110 | ALTER TABLE public.address 111 | ADD CONSTRAINT fk_address_to_person 112 | FOREIGN KEY (address_id) 113 | REFERENCES public.person (id); 114 | 115 | The comment 116 | 117 | COMMENT ON CONSTRAINT fk_address_to_person 118 | ON public.addresses 119 | IS E'@name Person Addresses'; 120 | 121 | Would give: 122 | - The "Person" graphql model an "Addresses" field 123 | - The "Adress" graphql model a "Person" field 124 | """ 125 | 126 | template = "@name {local_name_for_remote} {remote_name_for_local}" 127 | 128 | lines = comment.split("\n") 129 | for line in lines: 130 | line = line.strip() 131 | match = parse(template, line) 132 | if match is not None: 133 | local_name, remote_name = match.named.values() 134 | return ( 135 | local_name, 136 | remote_name, 137 | ) 138 | return None, None 139 | 140 | @classmethod 141 | def relationship_name_mapper(cls, relationship: RelationshipProperty) -> str: 142 | # Check the foreign key backing the relationship for a comment defining its name 143 | backing_fkey = get_foreign_key_constraint_from_relationship(relationship) 144 | if 
backing_fkey is not None: 145 | comment: str = get_comment(backing_fkey) 146 | local_name, remote_name = cls.read_foreign_key_comment_directive(comment) 147 | if local_name and remote_name: 148 | backing_fkey_parent_entity = backing_fkey.parent # type: ignore 149 | fkey_parent_table: Table = to_table(backing_fkey_parent_entity) 150 | relationship_parent_table: Table = to_table(relationship.parent) 151 | if fkey_parent_table == relationship_parent_table: 152 | return local_name 153 | return remote_name 154 | 155 | # No comment directive existed, checking if argument is string 156 | # Occurs when using human annotated sqla models vs reflection 157 | if isinstance(relationship.argument, str): 158 | return relationship.argument 159 | 160 | # No string provided name, reverting to default 161 | # Union of Mapper or ORM instance 162 | referred_cls = to_table(relationship.argument) 163 | 164 | referred_name = get_table_name(referred_cls) 165 | cardinal_name = to_plural(referred_name) if relationship.uselist else referred_name 166 | camel_name = snake_to_camel(cardinal_name, upper=False) 167 | 168 | relationship_name = ( 169 | camel_name 170 | + "By" 171 | + "And".join( 172 | snake_to_camel(local_col.name, upper=True) + "To" + snake_to_camel(remote_col.name, upper=True) 173 | for local_col, remote_col in relationship.local_remote_pairs 174 | ) 175 | ) 176 | return relationship_name 177 | 178 | # SQL Comment Config 179 | @staticmethod 180 | def _exclude_check( 181 | entity: Union[TableProtocol, Column, ForeignKeyConstraint], 182 | operation: Union[ 183 | Literal["read"], 184 | Literal["create"], 185 | Literal["update"], 186 | Literal["delete"], 187 | Literal["read_one"], 188 | Literal["read_all"], 189 | ], 190 | ) -> bool: 191 | """Shared SQL comment parsing logic for excludes""" 192 | comment: str = get_comment(entity) 193 | lines = comment.split("\n") 194 | for line in lines: 195 | if "@exclude" in line: 196 | operations_dirty = line[len("@exclude") :].split(",") 197 | 
operations_clean = {x.strip() for x in operations_dirty} 198 | if operation in operations_clean: 199 | return True 200 | return False 201 | 202 | ########### 203 | # Exclude # 204 | ########### 205 | 206 | @classmethod 207 | def exclude_read(cls, entity: Union[TableProtocol, Column, ForeignKeyConstraint]) -> bool: 208 | """Should the entity be excluded from reads? e.g. entity(nodeId ...) and allEntities(...)""" 209 | return cls._exclude_check(entity, "read") 210 | 211 | @classmethod 212 | def exclude_read_one(cls, entity: Union[TableProtocol, Column]) -> bool: 213 | """Should the entity be excluded from reads? e.g. entity(nodeId ...)""" 214 | return any([cls._exclude_check(entity, "read_one"), cls.exclude_read(entity)]) 215 | 216 | @classmethod 217 | def exclude_read_all(cls, entity: Union[TableProtocol, Column]) -> bool: 218 | """Should the entity be excluded from reads? e.g. allEntities(...)""" 219 | return any([cls._exclude_check(entity, "read_all"), cls.exclude_read(entity)]) 220 | 221 | @classmethod 222 | def exclude_create(cls, entity: Union[TableProtocol, Column]) -> bool: 223 | """Should the entity be excluded from create mutations?""" 224 | # Views do not support insert 225 | if isclass(entity) and issubclass(entity, ViewMixin): # type: ignore 226 | return True 227 | 228 | if isinstance(entity, Column) and entity.primary_key: 229 | # Primary keys must be server generated 230 | return True 231 | 232 | return cls._exclude_check(entity, "create") 233 | 234 | @classmethod 235 | def exclude_update(cls, entity: Union[TableProtocol, Column]) -> bool: 236 | """Should the entity be excluded from update mutations?""" 237 | 238 | # Views do not support updates 239 | if isclass(entity) and issubclass(entity, ViewMixin): # type: ignore 240 | return True 241 | 242 | if isinstance(entity, Column) and entity.primary_key: 243 | # Primary keys can not be upadted via the API 244 | return True 245 | 246 | return cls._exclude_check(entity, "update") 247 | 248 | @classmethod 
249 | def exclude_delete(cls, entity: Union[TableProtocol, Column]) -> bool: 250 | """Should the entity be excluded from delete mutations?""" 251 | 252 | # Views do not support deletes 253 | if isclass(entity) and issubclass(entity, ViewMixin): # type: ignore 254 | return True 255 | 256 | return cls._exclude_check(entity, "delete") 257 | -------------------------------------------------------------------------------- /src/test/pagination_test.py: -------------------------------------------------------------------------------- 1 | from nebulo.gql.relay.cursor import CursorStructure 2 | 3 | SQL_UP = """ 4 | CREATE TABLE account ( 5 | id serial primary key, 6 | name text not null, 7 | age int not null, 8 | created_at timestamp without time zone default (now() at time zone 'utc') 9 | ); 10 | 11 | INSERT INTO account (id, name, age) VALUES 12 | (1, 'oliver', 28), 13 | (2, 'rachel', 28), 14 | (3, 'sophie', 1), 15 | (4, 'buddy', 20), 16 | (5, 'foo', 99), 17 | (6, 'bar', 5), 18 | (7, 'baz', 35); 19 | """ 20 | 21 | 22 | def test_get_cursor(client_builder): 23 | client = client_builder(SQL_UP) 24 | gql_query = f""" 25 | {{ 26 | allAccounts {{ 27 | edges {{ 28 | cursor 29 | node {{ 30 | id 31 | }} 32 | }} 33 | }} 34 | }} 35 | """ 36 | with client: 37 | resp = client.post("/", json={"query": gql_query}) 38 | assert resp.status_code == 200 39 | 40 | result = resp.json() 41 | assert result["errors"] == [] 42 | cursor = result["data"]["allAccounts"]["edges"][2]["cursor"] 43 | assert cursor is not None 44 | 45 | 46 | def test_invalid_cursor(client_builder): 47 | client = client_builder(SQL_UP) 48 | 49 | cursor = CursorStructure(table_name="wrong_name", values={"id": 1}).serialize() 50 | # Query for 1 item after the cursor 51 | gql_query = f""" 52 | {{ 53 | allAccounts(first: 1, after: "{cursor}") {{ 54 | edges {{ 55 | cursor 56 | node {{ 57 | id 58 | }} 59 | }} 60 | }} 61 | }} 62 | """ 63 | with client: 64 | resp = client.post("/", json={"query": gql_query}) 65 | assert 
def _execute(client, query: str) -> dict:
    """POST *query* to the GraphQL endpoint and return the decoded JSON body.

    Asserts the transport-level status is 200; GraphQL-level errors are left
    in the payload for the caller to inspect.
    """
    with client:
        response = client.post("/", json={"query": query})
    assert response.status_code == 200
    return response.json()


def _accounts_query(arguments: str = "", fields: str = "id") -> str:
    """Build an allAccounts connection query selecting edge cursors plus
    the given node *fields*, optionally with pagination *arguments*."""
    head = f"allAccounts({arguments})" if arguments else "allAccounts"
    return "{ %s { edges { cursor node { %s } } } }" % (head, fields)


def test_retrieve_1_after_cursor(client_builder):
    client = client_builder(SQL_UP)

    result = _execute(client, _accounts_query())
    assert result["errors"] == []
    # Cursor pointing at the 2nd entry (rachel, id=2)
    cursor = result["data"]["allAccounts"]["edges"][1]["cursor"]

    # Fetch exactly one row following that cursor
    result = _execute(client, _accounts_query(f'first: 1, after: "{cursor}"'))
    assert result["errors"] == []
    assert result["data"]["allAccounts"]["edges"][0]["node"]["id"] == 3


def test_retrieve_1_before_cursor(client_builder):
    client = client_builder(SQL_UP)

    result = _execute(client, _accounts_query())
    assert result["errors"] == []
    # Cursor pointing at the 2nd entry
    cursor = result["data"]["allAccounts"]["edges"][1]["cursor"]
    print(cursor)

    # Fetch exactly one row preceding that cursor
    result = _execute(client, _accounts_query(f'last: 1, before: "{cursor}"'))
    assert result["errors"] == []
    assert result["data"]["allAccounts"]["edges"][0]["node"]["id"] == 1


def test_first_without_after(client_builder):
    client = client_builder(SQL_UP)

    result = _execute(client, _accounts_query())
    assert result["errors"] == []
    # Cursor to the 2nd entry (retrieved only to mirror the happy path)
    cursor = result["data"]["allAccounts"]["edges"][1]["cursor"]
    print(cursor)

    # "first" with no "after" starts from the beginning of the collection
    result = _execute(client, _accounts_query("first: 1"))
    assert result["errors"] == []
    assert result["data"]["allAccounts"]["edges"][0]["node"]["id"] == 1


def test_last_without_before(client_builder):
    client = client_builder(SQL_UP)

    result = _execute(client, _accounts_query())
    assert result["errors"] == []
    # Cursor to the 2nd entry (retrieved only to mirror the happy path)
    cursor = result["data"]["allAccounts"]["edges"][1]["cursor"]
    print(cursor)

    # "last" with no "before" counts back from the end of the collection
    result = _execute(client, _accounts_query("last: 1"))
    assert result["errors"] == []
    assert result["data"]["allAccounts"]["edges"][0]["node"]["id"] == 7


def test_pagination_order(client_builder):
    client = client_builder(SQL_UP)

    result = _execute(client, _accounts_query())
    assert result["errors"] == []

    limit = 2
    # Cursor for the "rachel" entry
    after_cursor = result["data"]["allAccounts"]["edges"][1]["cursor"]
    # Cursor for the "buddy" entry, id = 4
    before_cursor = result["data"]["allAccounts"]["edges"][3]["cursor"]

    # Two rows forward from "rachel"
    result = _execute(
        client,
        _accounts_query(f'first: {limit}, after: "{after_cursor}"', fields="id name"),
    )
    assert result["errors"] == []
    forward_names = [edge["node"]["name"] for edge in result["data"]["allAccounts"]["edges"]]
    assert forward_names == ["sophie", "buddy"]

    # Two rows backward from "buddy" — still returned in ascending order
    result = _execute(
        client,
        _accounts_query(f'last: {limit}, before: "{before_cursor}"', fields="id name"),
    )
    assert result["errors"] == []
    backward_names = [edge["node"]["name"] for edge in result["data"]["allAccounts"]["edges"]]
    assert backward_names == ["rachel", "sophie"]


def test_invalid_pagination_params(client_builder):
    client = client_builder(SQL_UP)
    cursor = CursorStructure(table_name="not used", values={"id": 1}).serialize()

    # Every mutually-exclusive combination must surface a GraphQL error
    invalid_argument_sets = [
        f'first: 1, before: "{cursor}"',  # first with before
        f'last: 1, after: "{cursor}"',  # last with after
        "first: 1, last: 1",  # first and last
        f'before: "{cursor}", after: "{cursor}"',  # before and after
    ]
    for arguments in invalid_argument_sets:
        result = _execute(client, _accounts_query(arguments))
        assert result["errors"] != []