├── .github └── workflows │ ├── compile.yml │ ├── docgen.yml │ ├── lint.yml │ └── test.yml ├── .gitignore ├── .luacheckrc ├── .luarc.json ├── .stylua.toml ├── ARCHITECTURE.md ├── LICENSE ├── README.md ├── assets └── screenshot.jpg ├── ci ├── build.sh ├── publish.sh ├── target-matrix.sh └── targets.json ├── dbee ├── adapters │ ├── adapters.go │ ├── bigquery.go │ ├── bigquery_driver.go │ ├── bigquery_driver_test.go │ ├── clickhouse.go │ ├── clickhouse_driver.go │ ├── databricks.go │ ├── databricks_driver.go │ ├── databricks_driver_test.go │ ├── databricks_test.go │ ├── duck.go │ ├── duck_driver.go │ ├── duck_driver_test.go │ ├── mongo.go │ ├── mongo_driver.go │ ├── mysql.go │ ├── mysql_driver.go │ ├── oracle.go │ ├── oracle_driver.go │ ├── postgres.go │ ├── postgres_driver.go │ ├── redis.go │ ├── redis_driver.go │ ├── redis_test.go │ ├── redshift.go │ ├── redshift_driver.go │ ├── sqlite.go │ ├── sqlite_driver.go │ ├── sqlserver.go │ └── sqlserver_driver.go ├── core │ ├── builders │ │ ├── client.go │ │ ├── client_options.go │ │ ├── columns.go │ │ ├── next.go │ │ ├── next_test.go │ │ └── result.go │ ├── call.go │ ├── call_archive.go │ ├── call_state.go │ ├── call_test.go │ ├── connection.go │ ├── connection_params.go │ ├── expand.go │ ├── expand_test.go │ ├── format │ │ ├── csv.go │ │ └── json.go │ ├── mock │ │ ├── adapter.go │ │ ├── adapter_options.go │ │ ├── result.go │ │ └── result_options.go │ ├── result.go │ ├── result_test.go │ └── types.go ├── endpoints.go ├── go.mod ├── go.sum ├── handler │ ├── call_log.go │ ├── event_bus.go │ ├── format_table.go │ ├── handler.go │ ├── marshal.go │ ├── output_buffer.go │ └── output_yank.go ├── main.go ├── plugin │ ├── logger.go │ ├── manifest.go │ └── plugin.go └── tests │ ├── README.md │ ├── integration │ ├── bigquery_integration_test.go │ ├── clickhouse_integration_test.go │ ├── docs.go │ ├── duckdb_integration_test.go │ ├── mysql_integration_test.go │ ├── oracle_integration_test.go │ ├── postgres_integration_test.go │ ├── 
redshift_integration_test.go │ ├── sqlite_integration_test.go │ └── sqlserver_integration_test.go │ ├── testdata │ ├── bigquery_seed.yaml │ ├── clickhouse_seed.sql │ ├── duckdb_seed.sql │ ├── mysql_seed.sql │ ├── oracle_seed.sql │ ├── postgres_seed.sql │ ├── sqlite_seed.sql │ └── sqlserver_seed.sql │ └── testhelpers │ ├── bigquery.go │ ├── clickhouse.go │ ├── duckdb.go │ ├── helper.go │ ├── mysql.go │ ├── oracle.go │ ├── postgres.go │ ├── sqlite.go │ └── sqlserver.go ├── doc ├── dbee-reference.txt └── dbee.txt ├── lua ├── dbee.lua └── dbee │ ├── api │ ├── __register.lua │ ├── core.lua │ ├── init.lua │ ├── state.lua │ └── ui.lua │ ├── config.lua │ ├── doc.lua │ ├── handler │ ├── __events.lua │ └── init.lua │ ├── health.lua │ ├── install │ ├── __manifest.lua │ └── init.lua │ ├── layouts │ ├── init.lua │ └── tools.lua │ ├── sources.lua │ ├── ui │ ├── call_log.lua │ ├── common │ │ ├── floats.lua │ │ └── init.lua │ ├── drawer │ │ ├── convert.lua │ │ ├── expansion.lua │ │ ├── init.lua │ │ └── menu.lua │ ├── editor │ │ ├── init.lua │ │ └── welcome.lua │ └── result │ │ ├── init.lua │ │ └── progress.lua │ └── utils.lua └── plugin └── dbee.lua /.github/workflows/docgen.yml: -------------------------------------------------------------------------------- 1 | name: Documentation Generation 2 | 3 | on: 4 | pull_request: 5 | branches: [master] 6 | push: 7 | branches: [master] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.event.pull_request.number || github.sha }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | readme-docs: 15 | runs-on: ubuntu-22.04 16 | name: Generate Docs from Readme 17 | env: 18 | TEMP_README: "__temp_readme.md" 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Prepare markdown file 22 | run: | 23 | TEMP_CONFIG="$(mktemp)" 24 | # Retrieve default config and put it in a temp file. 
25 | { 26 | echo '```lua' 27 | awk '/DOCGEN_END/{f=0} f; /DOCGEN_START/{f=1}' lua/dbee/config.lua 28 | echo '```' 29 | } > "$TEMP_CONFIG" 30 | # Insert the default config between DOCGEN_CONFIG tags in the README. 31 | # And remove stuff between DOCGEN_IGNORE_START and DOCGEN_IGNORE_END tags from README. 32 | { 33 | sed -e ' 34 | /DOCGEN_CONFIG_START/,/DOCGEN_CONFIG_END/!b 35 | /DOCGEN_CONFIG_START/r '"$TEMP_CONFIG"' 36 | /DOCGEN_CONFIG_END:/!d 37 | ' <(sed '/DOCGEN_IGNORE_START/,/DOCGEN_IGNORE_END/d' README.md) 38 | cat ARCHITECTURE.md 39 | } > "$TEMP_README" 40 | - name: Generate vimdoc 41 | uses: kdheepak/panvimdoc@v3.0.6 42 | with: 43 | vimdoc: dbee 44 | pandoc: "${{ env.TEMP_README }}" 45 | toc: true 46 | description: "Database Client for NeoVim" 47 | treesitter: true 48 | ignorerawblocks: true 49 | docmappingprojectname: false 50 | - name: Commit the Generated Help 51 | uses: EndBug/add-and-commit@v9 52 | if: github.event_name == 'push' 53 | with: 54 | add: doc/dbee.txt 55 | author_name: Github Actions 56 | author_email: actions@github.com 57 | message: "[docgen] Update doc/dbee.txt" 58 | pull: --rebase --autostash 59 | 60 | reference-docs: 61 | name: Generate Reference Docs 62 | runs-on: ubuntu-22.04 63 | steps: 64 | - uses: actions/checkout@v4 65 | - name: Generating help 66 | shell: bash 67 | run: | 68 | curl -Lq https://github.com/numToStr/lemmy-help/releases/latest/download/lemmy-help-x86_64-unknown-linux-gnu.tar.gz | tar xz 69 | ./lemmy-help lua/dbee.lua lua/dbee/{config,doc,sources,layouts/init,api/core,api/ui}.lua --expand-opt > doc/dbee-reference.txt 70 | - name: Commit the Generated Docs 71 | uses: EndBug/add-and-commit@v9 72 | if: github.event_name == 'push' 73 | with: 74 | add: doc/dbee-reference.txt 75 | author_name: Github Actions 76 | author_email: actions@github.com 77 | message: "[docgen] Update doc/dbee-reference.txt" 78 | pull: --rebase --autostash 79 | -------------------------------------------------------------------------------- 
/.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Linting and Style Checking 2 | 3 | on: 4 | pull_request: 5 | branches: [master] 6 | push: 7 | branches: [master] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.event.pull_request.number || github.sha }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | luacheck: 15 | runs-on: ubuntu-22.04 16 | name: Lint Lua Code 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Luacheck Linter 20 | uses: lunarmodules/luacheck@v0 21 | with: 22 | args: lua/ 23 | 24 | stylua: 25 | runs-on: ubuntu-22.04 26 | name: Check Lua Style 27 | steps: 28 | - uses: actions/checkout@v4 29 | - name: Lua Style Check 30 | uses: JohnnyMorganz/stylua-action@v4 31 | with: 32 | version: v0.17 33 | token: ${{ secrets.GITHUB_TOKEN }} 34 | args: --color always --check lua/ 35 | 36 | markdown-format: 37 | runs-on: ubuntu-22.04 38 | name: Check Markdown Format 39 | steps: 40 | - uses: actions/checkout@v4 41 | - name: Python Setup 42 | uses: actions/setup-python@v5 43 | with: 44 | python-version: "3.10" 45 | - name: Install mdformat 46 | run: | 47 | pip install mdformat-gfm 48 | - name: Markdown Style Check 49 | run: | 50 | mdformat --number --wrap 100 --check README.md ARCHITECTURE.md 51 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Testing 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [master] 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.event.pull_request.number || github.sha }} 10 | cancel-in-progress: true 11 | 12 | defaults: 13 | run: 14 | working-directory: dbee 15 | 16 | env: 17 | GO_VERSION: "1.23.x" 18 | 19 | jobs: 20 | go-unit-test: 21 | runs-on: ubuntu-22.04 22 | name: Go Unit Test 23 | steps: 24 | - uses: 
actions/checkout@v4 25 | - name: Setup Go v${{ env.GO_VERSION }} 26 | uses: actions/setup-go@v5 27 | with: 28 | go-version: v${{ env.GO_VERSION }} 29 | check-latest: true 30 | cache-dependency-path: ./dbee/go.sum 31 | # exclude tests folder (not used for unit tests) 32 | - name: Run Unit Tests 33 | run: go test $(go list ./... | grep -v tests) -v 34 | 35 | bootstrap-testcontainers: 36 | runs-on: ubuntu-22.04 37 | name: Bootstrap Testcontainers 38 | outputs: 39 | matrix: ${{ steps.generate-matrix.outputs.matrix }} 40 | steps: 41 | - uses: actions/checkout@v4 42 | - id: generate-matrix 43 | run: | 44 | # create a JSON object with the adapter names to bootstrap matrix with. 45 | matrix=$(find tests/integration -name '*_integration_test.go' -exec basename {} \; \ 46 | | sed 's/_integration_test.go//' \ 47 | | jq -scR 'split("\n") | map(select(length > 0)) | {adapter: .}') 48 | echo "matrix=$matrix" | tee $GITHUB_OUTPUT 49 | 50 | go-integration-test: 51 | needs: bootstrap-testcontainers 52 | runs-on: ubuntu-22.04 53 | timeout-minutes: 10 54 | env: 55 | TESTCONTAINERS_RYUK_DISABLED: true 56 | strategy: 57 | fail-fast: false 58 | matrix: ${{ fromJSON(needs.bootstrap-testcontainers.outputs.matrix) }} 59 | name: Go Integration Test (${{ matrix.adapter }}) 60 | steps: 61 | - uses: actions/checkout@v4 62 | - name: Setup Go v${{ env.GO_VERSION }} 63 | uses: actions/setup-go@v5 64 | with: 65 | go-version: ${{ env.GO_VERSION }} 66 | check-latest: true 67 | cache-dependency-path: ./dbee/go.sum 68 | - name: Run Integration Tests 69 | run: sudo go test ./tests/integration/${{ matrix.adapter }}_integration_test.go -v 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | doc/tags 2 | *.log 3 | -------------------------------------------------------------------------------- /.luacheckrc: 
-------------------------------------------------------------------------------- 1 | -- Show error codes in the output 2 | codes = true 3 | 4 | -- Disable unused argument warning for "self" 5 | self = false 6 | 7 | ignore = { 8 | "122", -- Indirectly setting a readonly global 9 | "631", -- line too long 10 | } 11 | 12 | -- Per file ignores. NOTE(review): lua/projector/ does not appear in this repo's lua/ tree (only lua/dbee*), so this glob looks copied from another project and matches nothing - TODO confirm and retarget 13 | files["lua/projector/contract/*"] = { ignore = { "212" } } -- Ignore unused argument warning for interfaces 14 | 15 | -- Global objects defined by the C code 16 | read_globals = { 17 | "vim", 18 | } 19 | -------------------------------------------------------------------------------- /.luarc.json: -------------------------------------------------------------------------------- 1 | { 2 | "workspace.checkThirdParty": false 3 | } 4 | -------------------------------------------------------------------------------- /.stylua.toml: -------------------------------------------------------------------------------- 1 | column_width = 120 2 | line_endings = "Unix" 3 | indent_type = "Spaces" 4 | indent_width = 2 5 | quote_style = "AutoPreferDouble" 6 | call_parentheses = "NoSingleTable" 7 | -------------------------------------------------------------------------------- /ARCHITECTURE.md: -------------------------------------------------------------------------------- 1 | # DBee Architecture Overview 2 | 3 | The plugin is created from 2 parts: 4 | 5 | - Go backend that interacts with databases, 6 | - Lua frontend which wraps the backend with neat nvim integration. 7 | 8 | These two parts should have clearly defined borders and not "leak" responsibilities. 9 | 10 | ## Lua Architecture 11 | 12 | The following diagram shows a high level overview of lua packages. Note that a lot of connections 13 | are removed for diagram clarity.
14 | 15 | ``` 16 | ui 17 | ┌──────────┐ 18 | ┌─────►│ Result ├──────┐ 19 | │ └──────────┘ │ 20 | │ │ 21 | │ ┌──────────┐ │ 22 | ├─────►│ Editor ├──────┤ core 23 | ┌──────────┐ ┌──────────┐ │ └──────────┘ │ ┌──────────┐ 24 | │ API ├──►│ entry ├──┤ ├────►│ Handler ├───► Go 25 | └──────────┘ └──────────┘ │ ┌──────────┐ │ └──────────┘ 26 | ├─────►│ Drawer ├──────┤ 27 | │ └──────────┘ │ 28 | │ │ 29 | │ ┌──────────┐ │ 30 | └─────►│ Call Log ├──────┘ 31 | └──────────┘ 32 | 33 | ┌──────────┐ ┌──────────┐ 34 | │ sources │ │ layouts │ 35 | └──────────┘ └──────────┘ 36 | ┌──────────┐ 37 | │ install │ 38 | └──────────┘ 39 | ``` 40 | 41 | Description: 42 | 43 | - The "dbee" package consists of 2 major functional packages, ui and handler (core). 44 | 45 | - `handler` (core) package is a wrapper around the go backend handler. The only extra thing the lua 46 | handler adds on top is information about sources. 47 | - `ui` package consists of the following packages: 48 | - `Drawer` represents the tree view. It uses the handler and editor to provide the view of 49 | connections and notes. 50 | - `Editor` represents the notepad view. It manages notes per namespace (namespace is an 51 | arbitrary name - Drawer uses it to have connection-local notes). 52 | - `Result` represents the results view. 53 | - `Call Log` represents the history of calls and supports managing past calls. 54 | 55 | - `install` package is independent of the other packages and is used for installation of the 56 | compiled go binary using the manifest generated by the CI pipeline. 57 | 58 | - `sources` package holds an implementation of some of the most common sources. 59 | 60 | - `layouts` package holds the implementation of the default window layout. 61 | 62 | ## Go Architecture 63 | 64 | As mentioned above, the Go backend is accessed exclusively through `handler` in lua.
The way the 65 | communication works both ways is that lua can call the handler methods directly and go triggers 66 | events that lua then listens to. 67 | 68 | An example of this event-based message passing is executing a query on a connection: 69 | 70 | - lua registers an event listener to display results. 71 | - lua calls the go execute method, which returns call details immediately. 72 | - lua then waits for the call to yield some results and displays them. 73 | 74 | One way of looking at the handler package in go is that it's just an implementation-specific use 75 | case of the `core` go package. 76 | 77 | ### Core package 78 | 79 | Here is the [godoc](https://pkg.go.dev/github.com/kndndrj/nvim-dbee/dbee/core) of the core package. 80 | 81 | The main construct is a `Connection`. It takes the parameters to connect to the database and an 82 | adapter for the database. An adapter is a database-specific provider, which returns a database 83 | driver and common database queries (helpers). 84 | 85 | The connection can then execute queries using the driver. This produces a `Call`. A call represents 86 | a single call to the database and holds its state. 87 | 88 | A database call returns a result, which transforms the iterator returned from the driver into 89 | different formats using a formatter. 90 | 91 | #### Adapters package 92 | 93 | Alongside `core` lives the `adapters` package. It contains implementations of multiple database 94 | drivers and adapters. One special thing it does is that it has its own method to create a 95 | connection. This is done so that individual adapters can register themselves on startup in their 96 | init functions (so that we can exclude some adapters on certain architectures/os-es). 97 | 98 | #### Builders package 99 | 100 | A subpackage of `core`, which holds convenience functions for creating some of the most used 101 | constructs. Examples are multiple implementations of the `ResultStream` interface.
102 | -------------------------------------------------------------------------------- /assets/screenshot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kndndrj/nvim-dbee/9656fc59841291e9dbd2f3b50b1cb4c77d9fea79/assets/screenshot.jpg -------------------------------------------------------------------------------- /ci/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # builds a go binary with the provided args 4 | set -e 5 | 6 | # args: 7 | goos="" # -o GOOS value 8 | goarch="" # -a GOARCH value 9 | crossarch="" # -c cgo cross compilation target 10 | buildtags="" # -b build arguments 11 | cgo=0 # -e cgo enabled (true or false) 12 | output="" # -p output path 13 | 14 | while getopts 'o:a:c:b:p:e:' opt; do 15 | case "$opt" in 16 | o) 17 | goos="$OPTARG" ;; 18 | a) 19 | goarch="$OPTARG" ;; 20 | c) 21 | crossarch="$OPTARG" ;; 22 | b) 23 | buildtags="$OPTARG" ;; 24 | p) 25 | output="$OPTARG" ;; 26 | e) 27 | [ "$OPTARG" = "true" ] && cgo=1 ;; 28 | *) 29 | # ignore invalid args 30 | echo "invalid flag: $opt" ;; 31 | esac 32 | done 33 | 34 | # check if cross platform is specified 35 | if [ -n "$crossarch" ]; then 36 | cc="zig cc -target $crossarch" 37 | cxx="zig c++ -target $crossarch" 38 | fi 39 | 40 | # Compile 41 | export CGO_ENABLED="$cgo" 42 | export CC="$cc" 43 | export CXX="$cxx" 44 | export GOOS="$goos" 45 | export GOARCH="$goarch" 46 | 47 | go build -tags="$buildtags" -o "$output" 48 | -------------------------------------------------------------------------------- /ci/publish.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # publishes the compiled binary to the bucket repository 4 | set -e 5 | 6 | # args: 7 | files="" # -a path to file(s) to add 8 | branch="" # -b branch name 9 | publish_user="" # -u publisher's username 10 | publish_token="" # -t publisher's token 11 | 
repo="" # -r short repo name - e.g. "owner/repo" 12 | message="" # -m commit message 13 | 14 | while getopts 'a:b:u:t:r:m:' opt; do 15 | case "$opt" in 16 | a) 17 | for f in $OPTARG; do 18 | files="$files $(realpath "$f")" 19 | done ;; 20 | b) 21 | branch="$OPTARG" ;; 22 | u) 23 | publish_user="$OPTARG" ;; 24 | t) 25 | publish_token="$OPTARG" ;; 26 | r) 27 | repo="$OPTARG" ;; 28 | m) 29 | message="$OPTARG" ;; 30 | *) 31 | # ignore invalid args 32 | echo "invalid flag: $opt" ;; 33 | esac 34 | done 35 | 36 | # validate input 37 | for var in "$files" "$branch" "$publish_user" "$publish_token" "$repo"; do 38 | if [ -z "$var" ]; then 39 | echo "some of the variables are not provided!" 40 | exit 1 41 | fi 42 | done 43 | 44 | # prepare temporary directory 45 | tempdir="$(mktemp -d)" 46 | cd "$tempdir" || exit 1 47 | 48 | # clone 49 | echo "cloning bucket repository" 50 | git clone https://"$publish_user":"$publish_token"@github.com/"$repo" bucket 51 | cd bucket || exit 1 52 | git config user.name "Github Actions" 53 | git config user.email "actions@github.com" 54 | 55 | # new branch 56 | git checkout -b "$branch" 2>/dev/null || git checkout "$branch" 57 | 58 | # add files to ./bin/ subdir 59 | echo "applying changes" 60 | mkdir -p bin/ 61 | # copy files 62 | for f in $files; do 63 | cp -r "$f" bin/ 64 | done 65 | git add bin/ 66 | [ -z "$message" ] && message="added $files" 67 | git commit -m "$message" 68 | 69 | # try publishing 10 times 70 | echo "trying to push to bucket repository..." 
71 | for i in 1 2 3 4 5 6 7 8 9 10; do 72 | echo "attempt $i/10" 73 | if (git push -u origin "$branch"); then 74 | echo "push succeeded after $i attempts" 75 | break 76 | fi 77 | 78 | git pull origin "$branch" --rebase || true 79 | 80 | if [ "$i" -eq 10 ]; then 81 | echo "push failed after 10 attempts" 82 | exit 1 83 | fi 84 | 85 | sleep 3 86 | done 87 | -------------------------------------------------------------------------------- /ci/target-matrix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # assembles github actions matrix from targets list 3 | 4 | default_buildplatform="ubuntu-latest" 5 | 6 | # handle primary platforms flag (used in pull_request CI/CD) 7 | if [ "$1" = "--primary" ]; then 8 | primary_filter='[.[] | select(.primary == true)]' 9 | else 10 | primary_filter='.' 11 | fi 12 | # strip comments 13 | targets="$(sed '/^\s*\/\//d;s/\/\/.*//' "$(dirname "$0")/targets.json")" 14 | 15 | # filter for primary platforms if requested 16 | targets="$(echo "$targets" | jq "$primary_filter")" 17 | 18 | # assign a default buildplatform 19 | targets="$(echo "$targets" | jq 'map( 20 | . 
+ if has("buildplatform") then 21 | {buildplatform} 22 | else 23 | {buildplatform: "'"$default_buildplatform"'"} 24 | end 25 | )')" 26 | 27 | # echo the matrix (remove newlines) 28 | echo 'matrix={"include":'"$targets"'}' | tr -d '\n' 29 | -------------------------------------------------------------------------------- /dbee/adapters/adapters.go: -------------------------------------------------------------------------------- 1 | package adapters 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "text/template" 8 | 9 | "github.com/kndndrj/nvim-dbee/dbee/core" 10 | ) 11 | 12 | var ( 13 | errNoValidTypeAliases = errors.New("no valid type aliases provided") 14 | ErrUnsupportedTypeAlias = errors.New("no driver registered for provided type alias") 15 | ) 16 | 17 | var _ core.Adapter = (*wrappedAdapter)(nil) 18 | 19 | // wrappedAdapter is returned from Mux and adds extra helpers to internal adapter. 20 | type wrappedAdapter struct { 21 | adapter core.Adapter 22 | extraHelpers map[string]*template.Template 23 | } 24 | 25 | // registeredAdapters holds implemented adapters - specific adapters register themselves in their init functions. 26 | // The main reason is to be able to compile the binary without unsupported os/arch of specific drivers. 27 | var registeredAdapters = make(map[string]*wrappedAdapter) 28 | 29 | // register registers a new adapter for specific database 30 | func register(adapter core.Adapter, aliases ...string) error { 31 | if len(aliases) < 1 { 32 | return errNoValidTypeAliases 33 | } 34 | 35 | value := &wrappedAdapter{ 36 | adapter: adapter, 37 | } 38 | 39 | invalidCount := 0 40 | for _, alias := range aliases { 41 | if alias == "" { 42 | invalidCount++ 43 | continue 44 | } 45 | registeredAdapters[alias] = value 46 | } 47 | 48 | if invalidCount == len(aliases) { 49 | return errNoValidTypeAliases 50 | } 51 | 52 | return nil 53 | } 54 | 55 | // Mux is an interface to all internal adapters. 
56 | type Mux struct{} 57 | 58 | func (*Mux) GetAdapter(typ string) (core.Adapter, error) { 59 | value, ok := registeredAdapters[typ] 60 | if !ok { 61 | return nil, ErrUnsupportedTypeAlias 62 | } 63 | 64 | return value, nil 65 | } 66 | 67 | func (*Mux) AddAdapter(typ string, adapter core.Adapter) error { 68 | return register(adapter, typ) 69 | } 70 | 71 | func (*Mux) AddHelpers(typ string, helpers map[string]string) error { 72 | value, ok := registeredAdapters[typ] 73 | if !ok { 74 | return ErrUnsupportedTypeAlias 75 | } 76 | // parse every template before touching value.extraHelpers, so one invalid template leaves the existing helpers unchanged (no partial update) 77 | parsed := make(map[string]*template.Template, len(helpers)) 78 | for k, v := range helpers { 79 | tmpl, err := template.New(k).Parse(v) 80 | if err != nil { 81 | return fmt.Errorf("template.New.Parse: %w", err) 82 | } 83 | parsed[k] = tmpl 84 | } 85 | if value.extraHelpers == nil { 86 | value.extraHelpers = make(map[string]*template.Template) 87 | } 88 | for k, tmpl := range parsed { // new helpers have priority 89 | value.extraHelpers[k] = tmpl 90 | } 91 | return nil 92 | } 93 | 94 | func (wa *wrappedAdapter) Connect(url string) (core.Driver, error) { 95 | return wa.adapter.Connect(url) 96 | } 97 | 98 | func (wa *wrappedAdapter) GetHelpers(opts *core.TableOptions) map[string]string { 99 | helpers := wa.adapter.GetHelpers(opts) 100 | if helpers == nil { 101 | helpers = make(map[string]string) 102 | } 103 | 104 | // extra helpers have priority 105 | for k, tmpl := range wa.extraHelpers { 106 | var out bytes.Buffer 107 | err := tmpl.Execute(&out, opts) 108 | if err != nil { 109 | continue 110 | } 111 | 112 | helpers[k] = out.String() 113 | } 114 | 115 | return helpers 116 | } 117 | 118 | // NewConnection is a wrapper around core.NewConnection that uses the internal mux for 119 | // adapter registration.
120 | func NewConnection(params *core.ConnectionParams) (*core.Connection, error) { 121 | adapter, err := new(Mux).GetAdapter(params.Expand().Type) 122 | if err != nil { 123 | return nil, fmt.Errorf("Mux.GetAdapter: %w", err) 124 | } 125 | 126 | c, err := core.NewConnection(params, adapter) 127 | if err != nil { 128 | return nil, fmt.Errorf("core.NewConnection: %w", err) 129 | } 130 | 131 | return c, nil 132 | } 133 | -------------------------------------------------------------------------------- /dbee/adapters/bigquery.go: -------------------------------------------------------------------------------- 1 | package adapters 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/url" 7 | 8 | "cloud.google.com/go/bigquery" 9 | "google.golang.org/api/option" 10 | "google.golang.org/api/option/internaloption" 11 | "google.golang.org/grpc" 12 | "google.golang.org/grpc/credentials/insecure" 13 | 14 | "github.com/kndndrj/nvim-dbee/dbee/core" 15 | ) 16 | 17 | // Register client 18 | func init() { 19 | _ = register(&BigQuery{}, "bigquery") 20 | } 21 | 22 | var _ core.Adapter = (*BigQuery)(nil) 23 | 24 | type BigQuery struct{} 25 | 26 | // Connect creates a [BigQuery] client connected to the project specified 27 | // in the url. The format of the url is as follows: 28 | // 29 | // bigquery://[project][?options] 30 | // 31 | // where project is optional. If not set, the project will attempt to be 32 | // detected from the credentials and current gcloud settings. 33 | // 34 | // The options query parameters map directly to [bigquery.QueryConfig] fields 35 | // using kebab-case. For example, MaxBytesBilled becomes max-bytes-billed.
36 | // 37 | // Common options include: 38 | // - credentials=path/to/creds.json: Path to credentials file 39 | // - max-bytes-billed=integer: Maximum bytes to be billed 40 | // - disable-query-cache=bool: Whether to disable query cache 41 | // - use-legacy-sql=bool: Whether to use legacy SQL 42 | // - location=string: Query location 43 | // - enable-storage-read=bool: Enable BigQuery Storage API 44 | // 45 | // For internal testing: 46 | // - endpoint=url: Custom endpoint for test containers 47 | // 48 | // If credentials are not specified, they will be located according to 49 | // the Google Default Credentials process. 50 | func (bq *BigQuery) Connect(rawURL string) (core.Driver, error) { 51 | ctx := context.Background() 52 | 53 | u, err := url.Parse(rawURL) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | if u.Scheme != "bigquery" { 59 | return nil, fmt.Errorf("unexpected scheme: %q", u.Scheme) 60 | } 61 | 62 | if u.Host == "" { 63 | u.Host = bigquery.DetectProjectID 64 | } 65 | 66 | options := []option.ClientOption{option.WithTelemetryDisabled()} 67 | params := u.Query() 68 | 69 | // special param to indicate we are running in testcontainer. 70 | if endpoint := params.Get("endpoint"); endpoint != "" { 71 | options = append(options, 72 | option.WithEndpoint(endpoint), 73 | option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), 74 | option.WithoutAuthentication(), 75 | internaloption.SkipDialSettingsValidation(), 76 | ) 77 | } else { 78 | callIfStringSet("credentials", params, func(file string) error { 79 | options = append(options, option.WithCredentialsFile(file)) 80 | return nil 81 | }) 82 | } 83 | 84 | bqc, err := bigquery.NewClient(ctx, u.Host, options...) 
85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | client := &bigQueryDriver{c: bqc} 90 | if err = setQueryConfigFromParams(&client.QueryConfig, params); err != nil { 91 | _ = bqc.Close() // the client was created above; release it on failure 92 | return nil, err 93 | } 94 | if err = callIfBoolSet("enable-storage-read", params, func() error { 95 | return client.c.EnableStorageReadClient(ctx, options...) 96 | }, nil); err != nil { 97 | _ = bqc.Close() // the client was created above; release it on failure 98 | return nil, err 99 | } 100 | return client, nil 101 | } 102 | 103 | func (*BigQuery) GetHelpers(opts *core.TableOptions) map[string]string { 104 | return map[string]string{ 105 | "List": fmt.Sprintf("SELECT * FROM `%s` TABLESAMPLE SYSTEM (5 PERCENT)", opts.Table), 106 | "Columns": fmt.Sprintf("SELECT * FROM `%s.INFORMATION_SCHEMA.COLUMNS` WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s'", opts.Schema, opts.Schema, opts.Table), 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /dbee/adapters/clickhouse.go: -------------------------------------------------------------------------------- 1 | package adapters 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/ClickHouse/clickhouse-go/v2" 9 | "github.com/kndndrj/nvim-dbee/dbee/core" 10 | "github.com/kndndrj/nvim-dbee/dbee/core/builders" 11 | ) 12 | 13 | // Register client 14 | func init() { 15 | _ = register(&Clickhouse{}, "clickhouse") 16 | } 17 | 18 | var _ core.Adapter = (*Clickhouse)(nil) 19 | 20 | type Clickhouse struct{} 21 | 22 | func (p *Clickhouse) Connect(url string) (core.Driver, error) { 23 | options, err := clickhouse.ParseDSN(url) 24 | if err != nil { 25 | return nil, fmt.Errorf("could not parse db connection string: %w", err) 26 | } 27 | 28 | jsonProcessor := func(a any) any { 29 | b, ok := a.([]byte) 30 | if !ok { 31 | return a 32 | } 33 | 34 | return newPostgresJSONResponse(b) 35 | } 36 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 37 | defer cancel() 38 | 39 | db := clickhouse.OpenDB(options) 40 | if err := db.PingContext(ctx);
err != nil { 41 | _ = db.Close() // the pool cannot reach the server; release it instead of leaking it 42 | return nil, fmt.Errorf("pinging connection failed with %w", err) 43 | } 44 | return &clickhouseDriver{ 45 | c: builders.NewClient(db, 46 | builders.WithCustomTypeProcessor("json", jsonProcessor), 47 | ), 48 | opts: options, 49 | }, nil 50 | } 51 | 52 | func (*Clickhouse) GetHelpers(opts *core.TableOptions) map[string]string { 53 | return map[string]string{ 54 | "List": fmt.Sprintf( 55 | "SELECT * FROM %q.%q LIMIT 500", 56 | opts.Schema, opts.Table, 57 | ), 58 | "Columns": fmt.Sprintf( 59 | "DESCRIBE %q.%q", 60 | opts.Schema, opts.Table, 61 | ), 62 | "Info": fmt.Sprintf( 63 | "SELECT * FROM system.tables WHERE database = '%s' AND name = '%s'", 64 | opts.Schema, opts.Table, 65 | ), 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /dbee/adapters/clickhouse_driver.go: -------------------------------------------------------------------------------- 1 | package adapters 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/ClickHouse/clickhouse-go/v2" 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | "github.com/kndndrj/nvim-dbee/dbee/core/builders" 10 | ) 11 | 12 | var ( 13 | _ core.Driver = (*clickhouseDriver)(nil) 14 | _ core.DatabaseSwitcher = (*clickhouseDriver)(nil) 15 | ) 16 | 17 | type clickhouseDriver struct { 18 | c *builders.Client 19 | opts *clickhouse.Options 20 | } 21 | 22 | func (c *clickhouseDriver) Query(ctx context.Context, query string) (core.ResultStream, error) { 23 | // run query, fallback to affected rows (NOTE(review): changes() is a SQLite function; confirm ClickHouse provides it, otherwise this fallback query fails) 24 | return c.c.QueryUntilNotEmpty(ctx, query, "select changes() as 'Rows Affected'") 25 | } 26 | 27 | func (c *clickhouseDriver) Columns(opts *core.TableOptions) ([]*core.Column, error) { 28 | return c.c.ColumnsFromQuery(` 29 | SELECT name, type 30 | FROM system.columns 31 | WHERE 32 | database='%s' AND 33 | table='%s' 34 | `, opts.Schema, opts.Table) 35 | } 36 | 37 | func (c *clickhouseDriver) Structure() ([]*core.Structure, error) { 38 | query := ` 39 | SELECT 40 | 
table_schema, table_name, table_type 41 | FROM information_schema.tables 42 | WHERE lower(table_schema) != 'information_schema' 43 | UNION ALL 44 | SELECT DISTINCT 45 | lower(table_schema), lower(table_name), table_type 46 | FROM information_schema.tables 47 | WHERE lower(table_schema) = 'information_schema'` 48 | 49 | rows, err := c.Query(context.TODO(), query) 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | return core.GetGenericStructure(rows, getPGStructureType) 55 | } 56 | 57 | func (c *clickhouseDriver) Close() { 58 | c.c.Close() 59 | } 60 | 61 | func (c *clickhouseDriver) ListDatabases() (current string, available []string, err error) { 62 | query := ` 63 | SELECT currentDatabase(), schema_name 64 | FROM information_schema.schemata 65 | WHERE schema_name NOT IN (currentDatabase(), 'INFORMATION_SCHEMA') 66 | ` 67 | 68 | rows, err := c.Query(context.TODO(), query) 69 | if err != nil { 70 | return "", nil, err 71 | } 72 | 73 | for rows.HasNext() { 74 | row, err := rows.Next() 75 | if err != nil { 76 | return "", nil, err 77 | } 78 | 79 | // We know for a fact there are 2 string fields (see query above). NOTE(review): if no other schemata exist, no rows come back and current stays "" 80 | current = row[0].(string) 81 | available = append(available, row[1].(string)) 82 | } 83 | 84 | return current, available, nil 85 | } 86 | 87 | func (c *clickhouseDriver) SelectDatabase(name string) error { 88 | oldDB := c.opts.Auth.Database 89 | c.opts.Auth.Database = name 90 | 91 | db := clickhouse.OpenDB(c.opts) 92 | if err := db.PingContext(context.Background()); err != nil { 93 | c.opts.Auth.Database = oldDB 94 | _ = db.Close() // the new pool is unusable and never swapped in; release it instead of leaking it 95 | return fmt.Errorf("pinging connection failed with %w", err) 96 | } 97 | c.c.Swap(db) 98 | return nil 99 | } 100 | -------------------------------------------------------------------------------- /dbee/adapters/databricks.go: -------------------------------------------------------------------------------- 1 | package adapters 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "net/url" 7 | 8 | _
"github.com/databricks/databricks-sql-go" 9 | 10 | "github.com/kndndrj/nvim-dbee/dbee/core" 11 | "github.com/kndndrj/nvim-dbee/dbee/core/builders" 12 | ) 13 | 14 | // Register client 15 | func init() { 16 | _ = register(&Databricks{}, "databricks") 17 | } 18 | 19 | var _ core.Adapter = (*Databricks)(nil) 20 | 21 | type Databricks struct{} 22 | 23 | // Connect parses the connectionURL and returns a new core.Driver 24 | // connectionURL is a DSN structure in the format of: 25 | // 26 | // token:[my_token]@[hostname]:[port]/[endpoint http path]?param=value 27 | // 28 | // requires the 'catalog' parameter to be set. 29 | 30 | // TODO: This could be extended with databricks connect by looking up 31 | // the config if connectionURL is empty. Added in the future 32 | 33 | // see https://github.com/databricks/databricks-sql-go for more information. 34 | func (d *Databricks) Connect(connectionURL string) (core.Driver, error) { 35 | parsedURL, err := url.Parse(connectionURL) 36 | if err != nil { 37 | return nil, fmt.Errorf("failed to parse connection string: %w: ", err) 38 | } 39 | 40 | // NOTE: we could add a PingContext with timeout here but I'll leave that 41 | // up to the user to add in the DSN URL (given databricks bootup time). 42 | db, err := sql.Open("databricks", parsedURL.String()) 43 | if err != nil { 44 | return nil, fmt.Errorf("invalid databricks connection string: %w", err) 45 | } 46 | 47 | currentCatalog := parsedURL.Query().Get("catalog") 48 | if currentCatalog == "" { 49 | return nil, fmt.Errorf("required parameter '?catalog=' is missing") 50 | } 51 | 52 | return &databricksDriver{ 53 | c: builders.NewClient(db), 54 | connectionURL: parsedURL, 55 | currentCatalog: currentCatalog, 56 | }, nil 57 | } 58 | 59 | // GetHelpers returns a map of helper queries for the given table. 
60 | func (d *Databricks) GetHelpers(opts *core.TableOptions) map[string]string {
61 | 	// TODO: extend this to include more helper queries
62 | 	list := fmt.Sprintf("SELECT * FROM %s.%s LIMIT 100;", opts.Schema, opts.Table)
63 | 	// information_schema.columns (plural) is the view listing table columns;
64 | 	// databricksDriver.Columns reads the same view.
65 | 	columns := fmt.Sprintf(`
66 | 		SELECT *
67 | 		FROM information_schema.columns
68 | 		WHERE table_schema = '%s'
69 | 			AND table_name = '%s';`,
70 | 		opts.Schema, opts.Table)
71 | 	describe := fmt.Sprintf("DESCRIBE EXTENDED %s.%s;", opts.Schema, opts.Table)
72 | 	constraints := fmt.Sprintf(`
73 | 		SELECT *
74 | 		FROM information_schema.table_constraints
75 | 		WHERE table_schema = '%s'
76 | 			AND table_name = '%s';`,
77 | 		opts.Schema, opts.Table)
78 | 	keys := fmt.Sprintf(`
79 | 		SELECT *
80 | 		FROM information_schema.key_column_usage
81 | 		WHERE table_schema = '%s'
82 | 			AND table_name = '%s';`,
83 | 		opts.Schema, opts.Table)
84 | 	return map[string]string{
85 | 		"List":        list,
86 | 		"Columns":     columns,
87 | 		"Describe":    describe,
88 | 		"Constraints": constraints,
89 | 		"Keys":        keys,
90 | 	}
91 | }
92 | 
--------------------------------------------------------------------------------
/dbee/adapters/databricks_driver.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"context"
5 | 	"database/sql"
6 | 	"fmt"
7 | 	"net/url"
8 | 
9 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
10 | 	"github.com/kndndrj/nvim-dbee/dbee/core/builders"
11 | )
12 | 
13 | var (
14 | 	_ core.Driver           = (*databricksDriver)(nil)
15 | 	_ core.DatabaseSwitcher = (*databricksDriver)(nil)
16 | )
17 | 
18 | // databricksDriver is a driver for Databricks.
19 | type databricksDriver struct {
20 | 	// c is the client used to execute queries.
21 | 	c              *builders.Client
22 | 	connectionURL  *url.URL
23 | 	currentCatalog string
24 | }
25 | 
26 | // Query executes the given query and returns the result stream.
27 | func (d *databricksDriver) Query(ctx context.Context, query string) (core.ResultStream, error) {
28 | 	return d.c.QueryUntilNotEmpty(ctx, query)
29 | }
30 | 
31 | // Columns returns the columns and their types for the given table.
32 | func (d *databricksDriver) Columns(opts *core.TableOptions) ([]*core.Column, error) {
33 | 	return d.c.ColumnsFromQuery(`
34 | 		SELECT column_name, data_type
35 | 		FROM information_schema.columns
36 | 		WHERE
37 | 			table_schema='%s' AND
38 | 			table_name='%s';`,
39 | 		opts.Schema, opts.Table)
40 | }
41 | 
42 | // Structure returns the structure of the current catalog/database.
43 | func (d *databricksDriver) Structure() ([]*core.Structure, error) {
44 | 	catalogQuery := fmt.Sprintf(`
45 | 		SELECT table_schema, table_name, table_type
46 | 		FROM system.information_schema.tables
47 | 		WHERE table_catalog = '%s'; `,
48 | 		d.currentCatalog)
49 | 
50 | 	rows, err := d.Query(context.Background(), catalogQuery)
51 | 	if err != nil {
52 | 		return nil, err
53 | 	}
54 | 
55 | 	return core.GetGenericStructure(rows, getDatabricksStructureType)
56 | }
57 | 
58 | // getDatabricksStructureType returns the core.StructureType based on the
59 | // given type string for databricks adapter.
60 | func getDatabricksStructureType(typ string) core.StructureType {
61 | 	switch typ {
62 | 	case "TABLE", "BASE TABLE", "SYSTEM TABLE", "MANAGED", "STREAMING_TABLE", "MANAGED_SHALLOW_CLONE", "MANAGED_DEEP_CLONE":
63 | 		return core.StructureTypeTable
64 | 	case "VIEW", "SYSTEM VIEW", "MATERIALIZED_VIEW":
65 | 		return core.StructureTypeView
66 | 	default:
67 | 		return core.StructureTypeNone
68 | 	}
69 | }
70 | 
71 | // Close closes the connection to the database.
72 | func (d *databricksDriver) Close() {
73 | 	d.c.Close()
74 | }
75 | 
76 | // ListDatabases returns the current catalog and a list of
77 | // available catalogs.
78 | func (d *databricksDriver) ListDatabases() (current string, available []string, err error) {
79 | 	query := `SHOW CATALOGS;`
80 | 
81 | 	rows, err := d.Query(context.Background(), query)
82 | 	if err != nil {
83 | 		return "", nil, err
84 | 	}
85 | 
86 | 	for rows.HasNext() {
87 | 		row, err := rows.Next()
88 | 		if err != nil {
89 | 			return "", nil, err
90 | 		}
91 | 
92 | 		catalog, ok := row[0].(string)
93 | 		if !ok {
94 | 			return "", nil, fmt.Errorf("expected string, got %T", row[0])
95 | 		}
96 | 		available = append(available, catalog)
97 | 	}
98 | 
99 | 	return d.currentCatalog, available, nil
100 | }
101 | 
102 | // SelectDatabase switches the current database/catalog to the selected one.
103 | // The catalog is set on a copy of the connection URL. The driver state (URL,
104 | // catalog, connection) is only updated once sql.Open has succeeded.
105 | func (d *databricksDriver) SelectDatabase(name string) error {
106 | 	newURL := *d.connectionURL
107 | 	q := newURL.Query()
108 | 	q.Set("catalog", name)
109 | 	newURL.RawQuery = q.Encode()
110 | 
111 | 	db, err := sql.Open("databricks", newURL.String())
112 | 	if err != nil {
113 | 		return fmt.Errorf("error switching catalog: %w", err)
114 | 	}
115 | 
116 | 	d.connectionURL = &newURL
117 | 	d.currentCatalog = name
118 | 	d.c.Swap(db)
119 | 
120 | 	return nil
121 | }
122 | 
--------------------------------------------------------------------------------
/dbee/adapters/databricks_test.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	_ "github.com/databricks/databricks-sql-go"
7 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
8 | 	"github.com/stretchr/testify/assert"
9 | 	"github.com/stretchr/testify/require"
10 | )
11 | 
12 | func TestDatabricks_Connect(t *testing.T) {
13 | 	tests := []struct {
14 | 		name          string
15 | 		connectionURL string
16 | 		wantErr       bool
17 | 		messageErr    string
18 | 	}{
19 | 		{
20 | 			name:          "should fail with invalid url format",
21 | 			connectionURL: "://invalid",
22 | 			wantErr:       true,
23 | 			messageErr:    "failed to parse connection string",
24 | 		},
25 | 		{
26 | 			name:          "should fail with missing catalog",
27 | 			connectionURL: "token:dummytoken@hostname:443/sql/1.0/endpoints/1234567890",
28 | 			wantErr:       true,
29 | 			messageErr:    "required parameter '?catalog=' is missing",
30 | 		},
31 | 		{
32 | 			name:          "should succeed with valid connection",
33 | 			connectionURL: "token:dummytoken@hostname:443/sql/1.0/endpoints/1234567890?catalog=my_catalog",
34 | 		},
35 | 	}
36 | 	for _, tt := range tests {
37 | 		tt := tt // capture the range variable for the parallel subtest (Go < 1.22)
38 | 		t.Run(tt.name, func(t *testing.T) {
39 | 			t.Parallel()
40 | 
41 | 			d := &Databricks{}
42 | 			got, err := d.Connect(tt.connectionURL)
43 | 
44 | 			if tt.wantErr {
45 | 				assert.NotEqual(t, "", tt.messageErr)
46 | 				assert.Error(t, err)
47 | 				assert.Contains(t, err.Error(), tt.messageErr)
48 | 				return
49 | 			}
50 | 			assert.NoError(t, err)
51 | 			assert.NotNil(t, got)
52 | 		})
53 | 	}
54 | }
55 | 
56 | func TestDatabricks_GetHelpers(t *testing.T) {
57 | 	defaultOpts := &core.TableOptions{
58 | 		Schema:          "test_schema",
59 | 		Table:           "test_table",
60 | 		Materialization: core.StructureTypeTable,
61 | 	}
62 | 	tests := []struct {
63 | 		name string
64 | 		key  string
65 | 		opts *core.TableOptions
66 | 		want string
67 | 	}{
68 | 		{
69 | 			name: "should return list query",
70 | 			key:  "List",
71 | 			opts: defaultOpts,
72 | 			want: "SELECT * FROM test_schema.test_table LIMIT 100;",
73 | 		},
74 | 		{
75 | 			name: "should return columns query",
76 | 			key:  "Columns",
77 | 			opts: defaultOpts,
78 | 			want: "\n\t\tSELECT *\n\t\tFROM information_schema.columns\n\t\tWHERE table_schema = 'test_schema'\n\t\t\tAND table_name = 'test_table';",
79 | 		},
80 | 		{
81 | 			name: "should return describe query",
82 | 			key:  "Describe",
83 | 			opts: defaultOpts,
84 | 			want: "DESCRIBE EXTENDED test_schema.test_table;",
85 | 		},
86 | 		{
87 | 			name: "should return constraints query",
88 | 			key:  "Constraints",
89 | 			opts: defaultOpts,
90 | 			want: "\n\t\tSELECT *\n\t\tFROM information_schema.table_constraints\n\t\tWHERE table_schema = 'test_schema'\n\t\t\tAND table_name = 'test_table';",
91 | 		},
92 | 		{
93 | 			name: "should return key_column_usage query",
94 | 			key:  "Keys",
95 | 			opts: defaultOpts,
96 | 			want: "\n\t\tSELECT *\n\t\tFROM information_schema.key_column_usage\n\t\tWHERE table_schema = 'test_schema'\n\t\t\tAND table_name = 'test_table';",
97 | 		},
98 | 	}
99 | 
100 | 	d := &Databricks{}
101 | 	helpers := d.GetHelpers(defaultOpts)
102 | 
103 | 	for helperKey := range helpers {
104 | 		var found bool
105 | 		for _, tt := range tests {
106 | 			if tt.key == helperKey {
107 | 				found = true
108 | 				break
109 | 			}
110 | 		}
111 | 		require.True(t, found, "missing test case for helper key: %q", helperKey)
112 | 	}
113 | 	for _, tt := range tests {
114 | 		tt := tt // capture the range variable for the parallel subtest (Go < 1.22)
115 | 		t.Run(tt.name, func(t *testing.T) {
116 | 			t.Parallel()
117 | 			got := helpers[tt.key]
118 | 			assert.Equal(t, tt.want, got)
119 | 		})
120 | 	}
121 | }
122 | 
--------------------------------------------------------------------------------
/dbee/adapters/duck.go:
--------------------------------------------------------------------------------
1 | //go:build cgo && ((darwin && (amd64 || arm64)) || (linux && (amd64 || arm64 || riscv64)))
2 | 
3 | package adapters
4 | 
5 | import (
6 | 	"database/sql"
7 | 	"fmt"
8 | 	"path/filepath"
9 | 	"strings"
10 | 
11 | 	_ "github.com/marcboeker/go-duckdb"
12 | 
13 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
14 | 	"github.com/kndndrj/nvim-dbee/dbee/core/builders"
15 | )
16 | 
17 | // Register client
18 | func init() {
19 | 	_ = register(&Duck{}, "duck", "duckdb")
20 | }
21 | 
22 | var _ core.Adapter = (*Duck)(nil)
23 | 
24 | type Duck struct{}
25 | 
26 | // parseDatabaseFromPath returns the part of the path's base name before the first dot, skipping one leading dot.
27 | func parseDatabaseFromPath(path string) string {
28 | 	base := filepath.Base(path)
29 | 	parts := strings.Split(base, ".")
30 | 	if len(parts) > 1 && parts[0] == "" {
31 | 		parts = parts[1:]
32 | 	}
33 | 	return parts[0]
34 | }
35 | 
36 | func (d *Duck) Connect(url string) (core.Driver, error) {
37 | 	db, err := sql.Open("duckdb", url)
38 | 	if err != nil {
39 | 		return nil, fmt.Errorf("unable to connect to duckdb database: %v", err)
40 | 	}
41 | 
42 | 	currentDB := "memory"
43 | 	if
url != "" {
44 | 		currentDB = parseDatabaseFromPath(url)
45 | 	}
46 | 
47 | 	return &duckDriver{
48 | 		c:         builders.NewClient(db),
49 | 		currentDB: currentDB,
50 | 	}, nil
51 | }
52 | 
53 | // GetHelpers returns helper queries for the given table. Every query is
54 | // qualified with opts.Schema, as duckDriver.Columns already does. This makes
55 | // tables outside the default schema, or tables sharing a name with one in
56 | // another schema, resolve to the intended object.
57 | func (*Duck) GetHelpers(opts *core.TableOptions) map[string]string {
58 | 	return map[string]string{
59 | 		"List":        fmt.Sprintf("SELECT * FROM %q.%q LIMIT 500", opts.Schema, opts.Table),
60 | 		"Columns":     fmt.Sprintf("DESCRIBE %q.%q", opts.Schema, opts.Table),
61 | 		"Indexes":     fmt.Sprintf("SELECT * FROM duckdb_indexes() WHERE schema_name = '%s' AND table_name = '%s'", opts.Schema, opts.Table),
62 | 		"Constraints": fmt.Sprintf("SELECT * FROM duckdb_constraints() WHERE schema_name = '%s' AND table_name = '%s'", opts.Schema, opts.Table),
63 | 	}
64 | }
65 | 
--------------------------------------------------------------------------------
/dbee/adapters/duck_driver.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 
7 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
8 | 	"github.com/kndndrj/nvim-dbee/dbee/core/builders"
9 | )
10 | 
11 | var (
12 | 	_ core.Driver           = (*duckDriver)(nil)
13 | 	_ core.DatabaseSwitcher = (*duckDriver)(nil)
14 | )
15 | 
16 | // duckDriver is the core.Driver for DuckDB. currentDB is the catalog name
17 | // that Structure filters information_schema.tables by.
18 | type duckDriver struct {
19 | 	c         *builders.Client
20 | 	currentDB string
21 | }
22 | 
23 | // Query executes the given query and returns the result stream.
24 | func (d *duckDriver) Query(ctx context.Context, query string) (core.ResultStream, error) {
25 | 	return d.c.QueryUntilNotEmpty(ctx, query)
26 | }
27 | 
28 | // Columns describes the schema-qualified table with DESCRIBE.
29 | func (d *duckDriver) Columns(opts *core.TableOptions) ([]*core.Column, error) {
30 | 	return d.c.ColumnsFromQuery("DESCRIBE %q.%q", opts.Schema, opts.Table)
31 | }
32 | 
33 | // Structure lists the information_schema.tables entries of the current catalog.
34 | func (d *duckDriver) Structure() ([]*core.Structure, error) {
35 | 	catalogQuery := fmt.Sprintf(`
36 | 		SELECT table_schema, table_name, table_type
37 | 		FROM information_schema.tables
38 | 		WHERE table_catalog = '%s';`,
39 | 		d.currentDB)
40 | 
41 | 	rows, err := d.Query(context.Background(), catalogQuery)
42 | 	if err != nil {
43 | 		return nil, err
44 | 	}
45 | 
46 | 	return core.GetGenericStructure(rows, getDuckDBStructureType)
47 | }
48 | 
49 | // getDuckDBStructureType returns the core.StructureType based on the
50 | // given type string for duckdb adapter.
51 | func getDuckDBStructureType(typ string) core.StructureType {
52 | 	// TODO: (phdah) Add more types if exists
53 | 	switch typ {
54 | 	case "BASE TABLE":
55 | 		return core.StructureTypeTable
56 | 	case "VIEW":
57 | 		return core.StructureTypeView
58 | 	default:
59 | 		return core.StructureTypeNone
60 | 	}
61 | }
62 | 
63 | // ListDatabases returns the current catalog and a list of available catalogs.
64 | // NOTE: (phdah) As of now, swapping catalogs is not enabled and only the
65 | // current will be shown
66 | func (d *duckDriver) ListDatabases() (current string, available []string, err error) {
67 | 	// no-op
68 | 	return d.currentDB, []string{"not supported yet"}, nil
69 | }
70 | 
71 | // SelectDatabase is currently a no-op: switching catalogs is not supported
72 | // yet (see ListDatabases), so the requested name is ignored.
73 | func (d *duckDriver) SelectDatabase(name string) error {
74 | 	return nil
75 | }
76 | 
77 | // Close closes the connection to the database.
78 | func (d *duckDriver) Close() {
79 | 	d.c.Close()
80 | }
81 | 
--------------------------------------------------------------------------------
/dbee/adapters/duck_driver_test.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 | 
9 | func Test_parseDatabaseFromPath(t *testing.T) {
10 | 	tests := []struct {
11 | 		name  string
12 | 		input string
13 | 		want  string
14 | 	}{
15 | 		{
16 | 			name:  "should return `test` part from unix file path",
17 | 			input: "/tmp/test.db",
18 | 			want:  "test",
19 | 		},
20 | 		{
21 | 			name:  "should return `.hiddenFile` part from unix file path",
22 | 			input: "/tmp/.hiddenFile.db",
23 | 			want:  "hiddenFile",
24 | 		},
25 | 		{
26 | 			name:  "should return `my_file` part from file url path",
27 | 			input: "file:///tmp/my_file.database",
28 | 			want:  "my_file",
29 | 		},
30 | 		{
31 | 			name:  "should return `my_db` part from s3 bucket url",
32 | 			input: "s3://bucket_name/path/to/my_db.duckdb",
33 | 			want:  "my_db",
34 | 		},
35 | 		{
36 | 			name:  "should return `remote_db` part from https url",
37 | 			input: "https://www.example.com/remote_db.example.new",
38 | 			want:  "remote_db",
39 | 		},
40 | 	}
41 | 	for _, tt := range tests {
42 | 		t.Run(tt.name, func(t *testing.T) {
43 | 			got := parseDatabaseFromPath(tt.input)
44 | 			assert.Equal(t, tt.want, got)
45 | 		})
46 | 	}
47 | }
48 | 
--------------------------------------------------------------------------------
/dbee/adapters/mongo.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"context"
5 | 	"encoding/gob"
6 | 	"fmt"
7 | 	"net/url"
8 | 	"strings"
9 | 
10 | 	"go.mongodb.org/mongo-driver/bson"
11 | 	"go.mongodb.org/mongo-driver/bson/primitive"
12 | 	"go.mongodb.org/mongo-driver/mongo"
13 | 	"go.mongodb.org/mongo-driver/mongo/options"
14 | 
15 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
16 | )
17 | 
18 | // Register client
19 | func init() {
20 | 	_ = register(&Mongo{}, "mongo", "mongodb")
21 | 
22 | 	// register known types with gob
23 | 	// full list available in go.mongodb.org/.../bson godoc
24 | 	gob.Register(&mongoResponse{})
25 | 	gob.Register(bson.A{})
26 | 	gob.Register(bson.M{})
27 | 	gob.Register(bson.D{})
28 | 	gob.Register(primitive.ObjectID{})
29 | 	// gob.Register(primitive.DateTime)
30 | 	gob.Register(primitive.Binary{})
31 | 	gob.Register(primitive.Regex{})
32 | 	// gob.Register(primitive.JavaScript)
33 | 	gob.Register(primitive.CodeWithScope{})
34 | 	gob.Register(primitive.Timestamp{})
35 | 	gob.Register(primitive.Decimal128{})
36 | 	// gob.Register(primitive.MinKey{})
37 | 	// gob.Register(primitive.MaxKey{})
38 | 	// gob.Register(primitive.Undefined{})
39 | 	gob.Register(primitive.DBPointer{})
40 | 	// gob.Register(primitive.Symbol)
41 | }
42 | 
43 | var _ core.Adapter = (*Mongo)(nil)
44 | 
45 | type Mongo struct{}
46 | 
47 | // Connect opens a MongoDB client for rawURL. The database name is the URL
48 | // path without its leading slash. It is empty when the URL has no path
49 | // (e.g. "mongodb://localhost:27017") instead of panicking on the slice.
50 | func (m *Mongo) Connect(rawURL string) (core.Driver, error) {
51 | 	// get database name from url
52 | 	u, err := url.Parse(rawURL)
53 | 	if err != nil {
54 | 		return nil, fmt.Errorf("mongo: invalid url: %w", err)
55 | 	}
56 | 	dbName := strings.TrimPrefix(u.Path, "/")
57 | 
58 | 	opts := options.Client().ApplyURI(rawURL)
59 | 	client, err := mongo.Connect(context.TODO(), opts)
60 | 	if err != nil {
61 | 		return nil, err
62 | 	}
63 | 
64 | 	return &mongoDriver{
65 | 		c:      client,
66 | 		dbName: dbName,
67 | 	}, nil
68 | }
69 | 
70 | func (*Mongo) GetHelpers(opts *core.TableOptions) map[string]string {
71 | 	return map[string]string{
72 | 		"List": fmt.Sprintf(`{"find": %q}`, opts.Table),
73 | 	}
74 | }
75 | 
--------------------------------------------------------------------------------
/dbee/adapters/mysql.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"database/sql"
5 | 	"fmt"
6 | 	"regexp"
7 | 
8 | 	_ "github.com/go-sql-driver/mysql"
9 | 
10 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
11 | 	"github.com/kndndrj/nvim-dbee/dbee/core/builders"
12 | )
13 | 
14 | // Register client
15 | func init() {
16 | 	_ = register(&MySQL{}, "mysql")
17 | }
18 | 
19 | var _ core.Adapter = (*MySQL)(nil)
20 | 
21 | type MySQL struct{}
22 | 
23 | func (m *MySQL) Connect(url string) (core.Driver, error) {
24 | 	// add multiple statements support parameter
25 | 	match, err := regexp.MatchString(`[\?][\w]+=[\w-]+`, url)
26 | 	if err != nil {
27 | 		return nil, err
28 | 	}
29 | 	sep := "?"
30 | 	if match {
31 | 		sep = "&"
32 | 	}
33 | 
34 | 	db, err := sql.Open("mysql", url+sep+"multiStatements=true")
35 | 	if err != nil {
36 | 		return nil, fmt.Errorf("unable to connect to mysql database: %v", err)
37 | 	}
38 | 
39 | 	return &mySQLDriver{
40 | 		c: builders.NewClient(db),
41 | 	}, nil
42 | }
43 | 
44 | func (*MySQL) GetHelpers(opts *core.TableOptions) map[string]string {
45 | 	return map[string]string{
46 | 		"List":         fmt.Sprintf("SELECT * FROM `%s`.`%s` LIMIT 500", opts.Schema, opts.Table),
47 | 		"Columns":      fmt.Sprintf("DESCRIBE `%s`.`%s`", opts.Schema, opts.Table),
48 | 		"Indexes":      fmt.Sprintf("SHOW INDEXES FROM `%s`.`%s`", opts.Schema, opts.Table),
49 | 		"Foreign Keys": fmt.Sprintf("SELECT * FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s' AND CONSTRAINT_TYPE = 'FOREIGN KEY'", opts.Schema, opts.Table),
50 | 		"Primary Keys": fmt.Sprintf("SELECT * FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE TABLE_SCHEMA = '%s' AND TABLE_NAME = '%s' AND CONSTRAINT_TYPE = 'PRIMARY KEY'", opts.Schema, opts.Table),
51 | 	}
52 | }
53 | 
--------------------------------------------------------------------------------
/dbee/adapters/mysql_driver.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"context"
5 | 
6 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
7 | 	"github.com/kndndrj/nvim-dbee/dbee/core/builders"
8 | )
9 | 
10 | var _ core.Driver = (*mySQLDriver)(nil)
11 | 
12 | type mySQLDriver struct {
13 | 	c *builders.Client
14 | }
15 | 
16 | func (c *mySQLDriver) Query(ctx context.Context, query string) (core.ResultStream, error) {
17 | 	// run query, fallback to affected rows
18 | 	return c.c.QueryUntilNotEmpty(ctx, query, "select ROW_COUNT() as 'Rows Affected'")
19 | }
20 | 
21 | func (c *mySQLDriver) Columns(opts *core.TableOptions) ([]*core.Column, error) {
22 | 	return c.c.ColumnsFromQuery("DESCRIBE `%s`.`%s`", opts.Schema, opts.Table)
23 | }
24 | 
25 | // Structure returns every table of information_schema.tables grouped under
26 | // its schema. Schemas are emitted in the order they first appear in the
27 | // ordered result set. Ranging over the children map instead would yield a
28 | // different, random order on every call.
29 | func (c *mySQLDriver) Structure() ([]*core.Structure, error) {
30 | 	query := `
31 | 		SELECT table_schema, table_name
32 | 		FROM information_schema.tables
33 | 		ORDER BY table_schema, table_name`
34 | 
35 | 	rows, err := c.Query(context.TODO(), query)
36 | 	if err != nil {
37 | 		return nil, err
38 | 	}
39 | 
40 | 	var schemas []string
41 | 	children := make(map[string][]*core.Structure)
42 | 
43 | 	for rows.HasNext() {
44 | 		row, err := rows.Next()
45 | 		if err != nil {
46 | 			return nil, err
47 | 		}
48 | 
49 | 		// We know for a fact there are 2 string fields (see query above)
50 | 		schema := row[0].(string)
51 | 		table := row[1].(string)
52 | 
53 | 		if _, seen := children[schema]; !seen {
54 | 			schemas = append(schemas, schema)
55 | 		}
56 | 		children[schema] = append(children[schema], &core.Structure{
57 | 			Name:   table,
58 | 			Schema: schema,
59 | 			Type:   core.StructureTypeTable,
60 | 		})
61 | 	}
62 | 
63 | 	var structure []*core.Structure
64 | 
65 | 	for _, schema := range schemas {
66 | 		structure = append(structure, &core.Structure{
67 | 			Name:     schema,
68 | 			Schema:   schema,
69 | 			Type:     core.StructureTypeNone,
70 | 			Children: children[schema],
71 | 		})
72 | 	}
73 | 
74 | 	return structure, nil
75 | }
76 | 
77 | func (c *mySQLDriver) Close() {
78 | 	c.c.Close()
79 | }
80 | 
--------------------------------------------------------------------------------
/dbee/adapters/oracle.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"database/sql"
5 | 	"fmt"
6 | 
7 | 	_ "github.com/sijms/go-ora/v2"
8 | 
9 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
10 | 	"github.com/kndndrj/nvim-dbee/dbee/core/builders"
11 | )
12 | 
13 | // Register client
14 | func init() {
15 | 	_ = register(&Oracle{}, "oracle")
16 | }
17 | 
18 | var _ core.Adapter = (*Oracle)(nil)
19 | 
20 | type Oracle struct{}
21 | 
22 | func (o *Oracle) Connect(url string) (core.Driver, error) {
23 | 	db, err := sql.Open("oracle", url)
24 | 	if err != nil {
25 | 		return nil, fmt.Errorf("unable to connect to oracle database: %v", err)
26 | 	}
27 | 
28 | 	return &oracleDriver{
29 | 		c: builders.NewClient(db),
30 | 	}, nil
31 | }
32 | 
33 | func (*Oracle) GetHelpers(opts *core.TableOptions) map[string]string {
34 | 	from
:= `
35 | 		FROM all_constraints N
36 | 		JOIN all_cons_columns L
37 | 		ON N.constraint_name = L.constraint_name
38 | 		AND N.owner = L.owner `
39 | 
40 | 	// Table names are only unique per owner, so filter on both.
41 | 	qualifyAndOrderBy := func(by string) string {
42 | 		return fmt.Sprintf(`
43 | 			L.owner = '%s' AND L.table_name = '%s'
44 | 			ORDER BY %s`, opts.Schema, opts.Table, by)
45 | 	}
46 | 
47 | 	keyCmd := func(constraint string) string {
48 | 		return fmt.Sprintf(`
49 | 			SELECT
50 | 				L.table_name,
51 | 				L.column_name
52 | 			%s
53 | 			WHERE
54 | 				N.constraint_type = '%s' AND %s`,
55 | 
56 | 			from,
57 | 			constraint,
58 | 			qualifyAndOrderBy("L.column_name"),
59 | 		)
60 | 	}
61 | 
62 | 	return map[string]string{
63 | 		"Columns": fmt.Sprintf(`SELECT col.column_id,
64 | 				col.owner AS schema_name,
65 | 				col.table_name,
66 | 				col.column_name,
67 | 				col.data_type,
68 | 				col.data_length,
69 | 				col.data_precision,
70 | 				col.data_scale,
71 | 				col.nullable
72 | 			FROM sys.all_tab_columns col
73 | 			INNER JOIN sys.all_tables t
74 | 				ON col.owner = t.owner
75 | 				AND col.table_name = t.table_name
76 | 			WHERE col.owner = '%s'
77 | 			AND col.table_name = '%s'
78 | 			ORDER BY col.owner, col.table_name, col.column_id `,
79 | 
80 | 			opts.Schema,
81 | 			opts.Table,
82 | 		),
83 | 
84 | 		"Foreign Keys": keyCmd("R"),
85 | 
86 | 		"Indexes": fmt.Sprintf(`
87 | 			SELECT DISTINCT
88 | 				N.owner,
89 | 				N.index_name,
90 | 				N.constraint_type
91 | 			%s
92 | 			WHERE %s `,
93 | 
94 | 			from,
95 | 			qualifyAndOrderBy("N.index_name"),
96 | 		),
97 | 
98 | 		"List": fmt.Sprintf("SELECT * FROM %q.%q", opts.Schema, opts.Table),
99 | 
100 | 		"Primary Keys": keyCmd("P"),
101 | 
102 | 		"References": fmt.Sprintf(`
103 | 			SELECT
104 | 				RFRING.owner,
105 | 				RFRING.table_name,
106 | 				RFRING.column_name
107 | 			FROM all_cons_columns RFRING
108 | 			JOIN all_constraints N
109 | 				ON RFRING.constraint_name = N.constraint_name AND RFRING.owner = N.owner
110 | 			JOIN all_cons_columns RFRD
111 | 				ON N.r_constraint_name = RFRD.constraint_name AND N.r_owner = RFRD.owner
112 | 			JOIN all_users U
113 | 				ON N.owner = U.username
114 | 			WHERE
115 | 				N.constraint_type = 'R'
116 | 				AND
117 | 				U.common = 'NO'
118 | 				AND
119 | 				RFRD.owner = '%s'
120 | 				AND
121 | 				RFRD.table_name = '%s'
122 | 			ORDER BY
123 | 				RFRING.owner,
124 | 				RFRING.table_name,
125 | 				RFRING.column_name`,
126 | 
127 | 			opts.Schema,
128 | 			opts.Table,
129 | 		),
130 | 	}
131 | }
132 | 
--------------------------------------------------------------------------------
/dbee/adapters/oracle_driver.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"context"
5 | 	"strings"
6 | 
7 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
8 | 	"github.com/kndndrj/nvim-dbee/dbee/core/builders"
9 | )
10 | 
11 | var _ core.Driver = (*oracleDriver)(nil)
12 | 
13 | type oracleDriver struct {
14 | 	c *builders.Client
15 | }
16 | 
17 | // Query runs INSERT, UPDATE and DELETE statements without a RETURNING clause
18 | // through Exec, and everything else as a row-returning query.
19 | func (d *oracleDriver) Query(ctx context.Context, query string) (core.ResultStream, error) {
20 | 	// Remove the trailing semicolon from the query - for some reason it isn't supported in go_ora
21 | 	query = strings.TrimSuffix(query, ";")
22 | 
23 | 	// Use Exec or Query depending on the query. The statement is split on any
24 | 	// whitespace, not only spaces, so that leading blanks or a newline after a
25 | 	// keyword do not hide the action or the RETURNING clause.
26 | 	words := strings.Fields(strings.ToLower(query))
27 | 	action := ""
28 | 	if len(words) > 0 {
29 | 		action = words[0]
30 | 	}
31 | 	hasReturnValues := false
32 | 	for _, word := range words {
33 | 		if word == "returning" {
34 | 			hasReturnValues = true
35 | 			break
36 | 		}
37 | 	}
38 | 	if (action == "update" || action == "delete" || action == "insert") && !hasReturnValues {
39 | 		return d.c.Exec(ctx, query)
40 | 	}
41 | 
42 | 	return d.c.QueryUntilNotEmpty(ctx, query)
43 | }
44 | 
45 | func (d *oracleDriver) Columns(opts *core.TableOptions) ([]*core.Column, error) {
46 | 	return d.c.ColumnsFromQuery(`
47 | 		SELECT
48 | 			col.column_name,
49 | 			col.data_type
50 | 		FROM sys.all_tab_columns col
51 | 		INNER JOIN sys.all_tables t
52 | 			ON col.owner = t.owner
53 | 			AND col.table_name = t.table_name
54 | 		WHERE col.owner = '%s'
55 | 		AND col.table_name = '%s'
56 | 		ORDER BY col.owner, col.table_name, col.column_id `,
57 | 
58 | 		opts.Schema,
59 | 		opts.Table)
60 | }
61 | 
62 | func (d *oracleDriver) Structure() ([]*core.Structure, error) {
63 | 	query := `
64 | 		SELECT owner, object_name, type
65 | 		FROM (
66 | 			SELECT owner, table_name as object_name, 'TABLE' as type
67 | 			FROM all_tables
68 | 			UNION ALL
69 | 			SELECT owner, table_name as object_name, 'EXTERNAL TABLE' as type
70 | 			FROM all_external_tables
71 | 			UNION ALL
72 | 			SELECT owner, view_name as object_name, 'VIEW' as type
73 | 			FROM all_views
74 | 			UNION ALL
75 | 			SELECT owner, mview_name as object_name, 'MATERIALIZED VIEW' as type
76 | 			FROM all_mviews
77 | 		)
78 | 		WHERE owner IN (SELECT username FROM all_users WHERE common = 'NO')
79 | 		ORDER BY owner, object_name
80 | 	`
81 | 
82 | 	rows, err := d.Query(context.TODO(), query)
83 | 	if err != nil {
84 | 		return nil, err
85 | 	}
86 | 
87 | 	decodeStructureType := func(s string) core.StructureType {
88 | 		switch s {
89 | 		case "TABLE", "EXTERNAL TABLE":
90 | 			return core.StructureTypeTable
91 | 		case "VIEW":
92 | 			return core.StructureTypeView
93 | 		case "MATERIALIZED VIEW":
94 | 			return core.StructureTypeMaterializedView
95 | 		default:
96 | 			return core.StructureTypeNone
97 | 		}
98 | 	}
99 | 
100 | 	return core.GetGenericStructure(rows, decodeStructureType)
101 | }
102 | 
103 | func (d *oracleDriver) Close() { d.c.Close() }
104 | 
--------------------------------------------------------------------------------
/dbee/adapters/postgres.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"database/sql"
5 | 	"encoding/gob"
6 | 	"fmt"
7 | 	nurl "net/url"
8 | 
9 | 	_ "github.com/lib/pq"
10 | 
11 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
12 | 	"github.com/kndndrj/nvim-dbee/dbee/core/builders"
13 | )
14 | 
15 | // Register client
16 | func init() {
17 | 	_ = register(&Postgres{}, "postgres", "postgresql", "pg")
18 | 
19 | 	// register special json response with gob
20 | 	gob.Register(&postgresJSONResponse{})
21 | }
22 | 
23 | var _ core.Adapter = (*Postgres)(nil)
24 | 
25 | type Postgres struct{}
26 | 
27 | func (p *Postgres) Connect(url string) (core.Driver, error) {
28 | 	u, err := nurl.Parse(url)
29 | 	if err != nil {
30 | 		return nil, fmt.Errorf("could not parse db connection string: %w", err)
31 | 	}
32 | 
33 |
db, err := sql.Open("postgres", u.String())
34 | 	if err != nil {
35 | 		return nil, fmt.Errorf("unable to connect to postgres database: %w", err)
36 | 	}
37 | 
38 | 	jsonProcessor := func(a any) any {
39 | 		b, ok := a.([]byte)
40 | 		if !ok {
41 | 			return a
42 | 		}
43 | 
44 | 		return newPostgresJSONResponse(b)
45 | 	}
46 | 
47 | 	return &postgresDriver{
48 | 		c: builders.NewClient(db,
49 | 			builders.WithCustomTypeProcessor("json", jsonProcessor),
50 | 			builders.WithCustomTypeProcessor("jsonb", jsonProcessor),
51 | 		),
52 | 		url: u,
53 | 	}, nil
54 | }
55 | 
56 | // GetHelpers returns helper queries for the given table.
57 | // The shared constraint query LEFT JOINs referential_constraints because that
58 | // view only holds foreign keys; a plain JOIN would drop every primary key.
59 | // All joins also match on constraint_schema, since constraint names are not
60 | // unique across schemas.
61 | func (*Postgres) GetHelpers(opts *core.TableOptions) map[string]string {
62 | 	basicConstraintQuery := `
63 | 	SELECT tc.constraint_name, tc.table_name, kcu.column_name, ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name, rc.update_rule, rc.delete_rule
64 | 	FROM
65 | 		information_schema.table_constraints AS tc
66 | 		JOIN information_schema.key_column_usage AS kcu
67 | 			ON tc.constraint_name = kcu.constraint_name AND tc.constraint_schema = kcu.constraint_schema
68 | 		LEFT JOIN information_schema.referential_constraints as rc
69 | 			ON tc.constraint_name = rc.constraint_name AND tc.constraint_schema = rc.constraint_schema
70 | 		JOIN information_schema.constraint_column_usage AS ccu
71 | 			ON ccu.constraint_name = tc.constraint_name AND ccu.constraint_schema = tc.constraint_schema
72 | 	`
73 | 
74 | 	return map[string]string{
75 | 		"List":    fmt.Sprintf("SELECT * FROM %q.%q LIMIT 500", opts.Schema, opts.Table),
76 | 		"Columns": fmt.Sprintf("SELECT * FROM information_schema.columns WHERE table_name='%s' AND table_schema='%s'", opts.Table, opts.Schema),
77 | 		"Indexes": fmt.Sprintf("SELECT * FROM pg_indexes WHERE tablename='%s' AND schemaname='%s'", opts.Table, opts.Schema),
78 | 		"Foreign Keys": fmt.Sprintf("%s WHERE constraint_type = 'FOREIGN KEY' AND tc.table_name = '%s' AND tc.table_schema = '%s'",
79 | 			basicConstraintQuery,
80 | 			opts.Table,
81 | 			opts.Schema,
82 | 		),
83 | 		// References are foreign keys pointing at this table, so the name and
84 | 		// the schema are both matched on the referenced side (ccu).
85 | 		"References": fmt.Sprintf("%s WHERE constraint_type = 'FOREIGN KEY' AND ccu.table_name = '%s' AND ccu.table_schema = '%s'",
86 | 			basicConstraintQuery,
87 | 			opts.Table,
88 | 			opts.Schema,
89 | 		),
90 | 		"Primary Keys": fmt.Sprintf("%s WHERE constraint_type = 'PRIMARY KEY' AND tc.table_name = '%s' AND tc.table_schema = '%s'",
91 | 			basicConstraintQuery,
92 | 			opts.Table,
93 | 			opts.Schema,
94 | 		),
95 | 	}
96 | }
97 | 
--------------------------------------------------------------------------------
/dbee/adapters/redis.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"encoding/gob"
5 | 	"fmt"
6 | 
7 | 	"github.com/redis/go-redis/v9"
8 | 
9 | 	"github.com/kndndrj/nvim-dbee/dbee/core"
10 | )
11 | 
12 | // Register client
13 | func init() {
14 | 	_ = register(&Redis{}, "redis")
15 | 
16 | 	// register known types with gob
17 | 	gob.Register(&redisResponse{})
18 | 	gob.Register([]any{})
19 | 	gob.Register(map[any]any{})
20 | }
21 | 
22 | var _ core.Adapter = (*Redis)(nil)
23 | 
24 | type Redis struct{}
25 | 
26 | func (r *Redis) Connect(url string) (core.Driver, error) {
27 | 	opt, err := redis.ParseURL(url)
28 | 	if err != nil {
29 | 		return nil, fmt.Errorf("unable to connect to redis database: %v", err)
30 | 	}
31 | 	c := redis.NewClient(opt)
32 | 
33 | 	return &redisDriver{
34 | 		redis: c,
35 | 	}, nil
36 | }
37 | 
38 | func (*Redis) GetHelpers(opts *core.TableOptions) map[string]string {
39 | 	return map[string]string{
40 | 		"List": "KEYS *",
41 | 	}
42 | }
43 | 
--------------------------------------------------------------------------------
/dbee/adapters/redis_test.go:
--------------------------------------------------------------------------------
1 | package adapters
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/stretchr/testify/require"
7 | )
8 | 
9 | func TestParseRedisCmd(t *testing.T) {
10 | 	r := require.New(t)
11 | 
12 | 	type testCase struct {
13 | 		unparsed       string
14 | 		expectedResult []any
15 | 		expectedError  error
16 | 	}
17 | 
18 | 	testCases := []testCase{
19 | 		// these should work
20 | 		{
21 | 			unparsed:       `set key val`,
22 | 			expectedResult: []any{"set", "key", "val"},
23 | 			expectedError:  nil,
24 | 		},
25 | 		{
26 | 			unparsed:       `set key "double quoted val"`,
27 | 			expectedResult: []any{"set", "key", "double quoted val"},
28 | 			expectedError:  nil,
29 | 		},
30 | 		{
31 | 			unparsed:       `set key 'single quoted val'`,
32 | 			expectedResult: []any{"set", "key", "single quoted val"},
33 | 			expectedError:  nil,
34 | 		},
35 | 		{
36 | 			unparsed:       `set key 'single quoted val with nested unescaped double quote (")'`,
37 | 			expectedResult: []any{"set", "key", "single quoted val with nested unescaped double quote (\")"},
38 | 			expectedError:  nil,
39 | 		},
40 | 		{
41 | 			unparsed:       `set key 'single quoted val with nested escaped double quote (\")'`,
42 | 			expectedResult: []any{"set", "key", "single quoted val with nested escaped double quote (\")"},
43 | 			expectedError:  nil,
44 | 		},
45 | 		{
46 | 			unparsed:       `set key 'single quoted val with nested escaped single quote (\')'`,
47 | 			expectedResult: []any{"set", "key", "single quoted val with nested escaped single quote (')"},
48 | 			expectedError:  nil,
49 | 		},
50 | 		{
51 | 			unparsed:       `set key "double quoted val with nested unescaped single quote (')"`,
52 | 			expectedResult: []any{"set", "key", "double quoted val with nested unescaped single quote (')"},
53 | 			expectedError:  nil,
54 | 		},
55 | 		{
56 | 			unparsed:       `set key "double quoted val with nested escaped single quote (\')"`,
57 | 			expectedResult: []any{"set", "key", "double quoted val with nested escaped single quote (')"},
58 | 			expectedError:  nil,
59 | 		},
60 | 		{
61 | 			unparsed:       `set key "double quoted val with nested escaped double quote (\")"`,
62 | 			expectedResult: []any{"set", "key", "double quoted val with nested escaped double quote (\")"},
63 | 			expectedError:  nil,
64 | 		},
65 | 
66 | 		// these shouldn't work
67 | 		{
68 | 			unparsed:       `set key "unmatched double quoted val`,
69 | 			expectedResult: nil,
70 | 			expectedError:  ErrUnmatchedDoubleQuote(9),
71 | 		},
72 | 		{
73 | 			unparsed:       `set key 'unmatched single quoted val`,
74 | 			expectedResult: nil,
75 | 			expectedError:  ErrUnmatchedSingleQuote(9),
76 | 		},
77 | 		{
78 | 			unparsed:       `set key "double
quoted val with nested unescaped double quote (")"`, 79 | expectedResult: nil, 80 | expectedError: ErrUnmatchedDoubleQuote(64), 81 | }, 82 | { 83 | unparsed: `set key 'single quoted val with nested unescaped single quote (')'`, 84 | expectedResult: nil, 85 | expectedError: ErrUnmatchedSingleQuote(64), 86 | }, 87 | } 88 | 89 | for _, tc := range testCases { 90 | parsed, err := parseRedisCmd(tc.unparsed) 91 | if err != nil { 92 | r.Equal(err.Error(), tc.expectedError.Error()) 93 | continue 94 | } 95 | r.Equal(parsed, tc.expectedResult) 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /dbee/adapters/redshift.go: -------------------------------------------------------------------------------- 1 | package adapters 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "net/url" 8 | "time" 9 | 10 | "github.com/kndndrj/nvim-dbee/dbee/core" 11 | "github.com/kndndrj/nvim-dbee/dbee/core/builders" 12 | ) 13 | 14 | // init registers the RedshiftClient to the store, 15 | // i.e. to lua frontend. 16 | func init() { 17 | _ = register(&Redshift{}, "redshift") 18 | } 19 | 20 | var _ core.Adapter = (*Redshift)(nil) 21 | 22 | type Redshift struct{} 23 | 24 | func (r *Redshift) Connect(rawURL string) (core.Driver, error) { 25 | connURL, err := url.Parse(rawURL) 26 | if err != nil { 27 | return nil, fmt.Errorf("failed to parse connection string: %w", err) 28 | } 29 | 30 | // TODO: perhaps better to use something else than postgres driver.. 
31 | db, err := sql.Open("postgres", connURL.String()) 32 | if err != nil { 33 | return nil, fmt.Errorf("unable to connect to redshift: %w", err) 34 | } 35 | 36 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 37 | defer cancel() 38 | if err := db.PingContext(ctx); err != nil { 39 | return nil, fmt.Errorf("unable to ping redshift: %w", err) 40 | } 41 | 42 | return &redshiftDriver{ 43 | c: builders.NewClient(db), 44 | connectionURL: connURL, 45 | }, nil 46 | } 47 | 48 | func (r *Redshift) GetHelpers(opts *core.TableOptions) map[string]string { 49 | out := make(map[string]string, 0) 50 | list := fmt.Sprintf("SELECT * FROM %q.%q LIMIT 100;", opts.Schema, opts.Table) 51 | 52 | switch opts.Materialization { 53 | case core.StructureTypeTable: 54 | out = map[string]string{ 55 | "List": list, 56 | "Columns": fmt.Sprintf("SELECT * FROM information_schema.columns WHERE table_name='%s' AND table_schema='%s';", opts.Table, opts.Schema), 57 | "Indexes": fmt.Sprintf("SELECT * FROM pg_indexes WHERE tablename='%s' AND schemaname='%s';", opts.Table, opts.Schema), 58 | "Foreign Keys": fmt.Sprintf(` 59 | SELECT tc.constraint_name, tc.table_name, kcu.column_name, ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name, rc.update_rule, rc.delete_rule 60 | FROM 61 | information_schema.table_constraints AS tc 62 | JOIN information_schema.key_column_usage AS kcu 63 | ON tc.constraint_name = kcu.constraint_name 64 | JOIN information_schema.referential_constraints as rc 65 | ON tc.constraint_name = rc.constraint_name 66 | JOIN information_schema.constraint_column_usage AS ccu 67 | ON ccu.constraint_name = tc.constraint_name 68 | WHERE constraint_type = 'FOREIGN KEY' AND tc.table_name = '%s' AND tc.table_schema = '%s';`, 69 | 70 | opts.Table, 71 | opts.Schema, 72 | ), 73 | "Table Definition": fmt.Sprintf(` 74 | SELECT 75 | * 76 | FROM svv_table_info 77 | WHERE "schema" = '%s' 78 | AND "table" = '%s';`, 79 | 80 | opts.Schema, 81 | opts.Table, 82 | 
), 83 | } 84 | 85 | case core.StructureTypeView: 86 | out = map[string]string{ 87 | "List": list, 88 | "View Definition": fmt.Sprintf(` 89 | SELECT 90 | * 91 | FROM pg_views 92 | WHERE schemaname = '%s' 93 | AND viewname = '%s';`, 94 | 95 | opts.Schema, 96 | opts.Table, 97 | ), 98 | } 99 | } 100 | 101 | return out 102 | } 103 | -------------------------------------------------------------------------------- /dbee/adapters/redshift_driver.go: -------------------------------------------------------------------------------- 1 | package adapters 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "net/url" 8 | "time" 9 | 10 | _ "github.com/lib/pq" 11 | 12 | "github.com/kndndrj/nvim-dbee/dbee/core" 13 | "github.com/kndndrj/nvim-dbee/dbee/core/builders" 14 | ) 15 | 16 | var ( 17 | _ core.Driver = (*redshiftDriver)(nil) 18 | _ core.DatabaseSwitcher = (*redshiftDriver)(nil) 19 | ) 20 | 21 | // redshiftDriver is a sql client for redshiftDriver. 22 | // Mainly uses the postgres driver under the hood but with 23 | // custom Layout function to get the table and view names correctly. 24 | type redshiftDriver struct { 25 | c *builders.Client 26 | connectionURL *url.URL 27 | } 28 | 29 | // Query executes a query and returns the result as an IterResult. 30 | func (r *redshiftDriver) Query(ctx context.Context, query string) (core.ResultStream, error) { 31 | return r.c.QueryUntilNotEmpty(ctx, query) 32 | } 33 | 34 | // Close closes the underlying sql.DB connection. 35 | func (r *redshiftDriver) Close() { 36 | r.c.Close() 37 | } 38 | 39 | func (r *redshiftDriver) Columns(opts *core.TableOptions) ([]*core.Column, error) { 40 | return r.c.ColumnsFromQuery(` 41 | SELECT column_name, data_type 42 | FROM information_schema.columns 43 | WHERE 44 | table_schema='%s' AND 45 | table_name='%s' 46 | `, opts.Schema, opts.Table) 47 | } 48 | 49 | // Structure returns the layout of the database. This represents the 50 | // "schema" with all the tables and views. 
Note that ordering is not 51 | // done here. The ordering is done in the lua frontend. 52 | func (r *redshiftDriver) Structure() ([]*core.Structure, error) { 53 | query := ` 54 | SELECT 55 | trim(n.nspname) AS schema_name, 56 | trim(c.relname) AS table_name, 57 | CASE 58 | WHEN c.relkind = 'v' THEN 'VIEW' 59 | ELSE 'TABLE' 60 | END AS table_type 61 | FROM 62 | pg_class AS c 63 | INNER JOIN 64 | pg_namespace AS n ON c.relnamespace = n.oid 65 | WHERE 66 | n.nspname NOT IN ('information_schema', 'pg_catalog'); 67 | ` 68 | 69 | rows, err := r.Query(context.Background(), query) 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | return core.GetGenericStructure(rows, getPGStructureType) 75 | } 76 | 77 | func (r *redshiftDriver) ListDatabases() (current string, available []string, err error) { 78 | query := ` 79 | SELECT current_database() AS current, datname 80 | FROM pg_database 81 | WHERE datistemplate = false 82 | AND datname != current_database();` 83 | 84 | rows, err := r.Query(context.Background(), query) 85 | if err != nil { 86 | return "", nil, err 87 | } 88 | 89 | for rows.HasNext() { 90 | row, err := rows.Next() 91 | if err != nil { 92 | return "", nil, err 93 | } 94 | 95 | // current database is the first column, available databases are the rest 96 | current = row[0].(string) 97 | available = append(available, row[1].(string)) 98 | } 99 | 100 | return current, available, nil 101 | } 102 | 103 | func (r *redshiftDriver) SelectDatabase(name string) error { 104 | r.connectionURL.Path = fmt.Sprintf("/%s", name) 105 | db, err := sql.Open("postgres", r.connectionURL.String()) 106 | if err != nil { 107 | return fmt.Errorf("unable to switch databases: %w", err) 108 | } 109 | 110 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 111 | defer cancel() 112 | if err = db.PingContext(ctx); err != nil { 113 | return fmt.Errorf("unable to ping redshift: %w", err) 114 | } 115 | 116 | r.c.Swap(db) 117 | return nil 118 | } 119 | 
-------------------------------------------------------------------------------- /dbee/adapters/sqlite.go: -------------------------------------------------------------------------------- 1 | //go:build (darwin && (amd64 || arm64)) || (freebsd && (386 || amd64 || arm || arm64)) || (linux && (386 || amd64 || arm || arm64 || ppc64le || riscv64 || s390x)) || (netbsd && amd64) || (openbsd && (amd64 || arm64)) || (windows && (amd64 || arm64)) 2 | 3 | package adapters 4 | 5 | import ( 6 | "database/sql" 7 | "fmt" 8 | "os/user" 9 | "path/filepath" 10 | "strings" 11 | 12 | _ "modernc.org/sqlite" 13 | 14 | "github.com/kndndrj/nvim-dbee/dbee/core" 15 | "github.com/kndndrj/nvim-dbee/dbee/core/builders" 16 | ) 17 | 18 | // Register client 19 | func init() { 20 | _ = register(&SQLite{}, "sqlite", "sqlite3") 21 | } 22 | 23 | var _ core.Adapter = (*SQLite)(nil) 24 | 25 | type SQLite struct{} 26 | 27 | func (s *SQLite) expandPath(path string) (string, error) { 28 | usr, err := user.Current() 29 | if err != nil { 30 | return "", fmt.Errorf("user.Current: %w", err) 31 | } 32 | 33 | if path == "~" { 34 | return usr.HomeDir, nil 35 | } else if strings.HasPrefix(path, "~/") { 36 | return filepath.Join(usr.HomeDir, path[2:]), nil 37 | } 38 | 39 | return path, nil 40 | } 41 | 42 | func (s *SQLite) Connect(url string) (core.Driver, error) { 43 | path, err := s.expandPath(url) 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | db, err := sql.Open("sqlite", path) 49 | if err != nil { 50 | return nil, fmt.Errorf("unable to connect to sqlite database: %v", err) 51 | } 52 | 53 | return &sqliteDriver{ 54 | c: builders.NewClient(db), 55 | currentDatabase: path, 56 | }, nil 57 | } 58 | 59 | func (*SQLite) GetHelpers(opts *core.TableOptions) map[string]string { 60 | return map[string]string{ 61 | "List": fmt.Sprintf("SELECT * FROM %q LIMIT 500", opts.Table), 62 | "Columns": fmt.Sprintf("PRAGMA table_info('%s')", opts.Table), 63 | "Indexes": fmt.Sprintf("SELECT * FROM 
pragma_index_list('%s')", opts.Table), 64 | "Foreign Keys": fmt.Sprintf("SELECT * FROM pragma_foreign_key_list('%s')", opts.Table), 65 | "Primary Keys": fmt.Sprintf("SELECT * FROM pragma_index_list('%s') WHERE origin = 'pk'", opts.Table), 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /dbee/adapters/sqlite_driver.go: -------------------------------------------------------------------------------- 1 | package adapters 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/kndndrj/nvim-dbee/dbee/core" 7 | "github.com/kndndrj/nvim-dbee/dbee/core/builders" 8 | ) 9 | 10 | var ( 11 | _ core.Driver = (*sqliteDriver)(nil) 12 | _ core.DatabaseSwitcher = (*sqliteDriver)(nil) 13 | ) 14 | 15 | type sqliteDriver struct { 16 | c *builders.Client 17 | currentDatabase string 18 | } 19 | 20 | func (d *sqliteDriver) Query(ctx context.Context, query string) (core.ResultStream, error) { 21 | // run query, fallback to affected rows 22 | return d.c.QueryUntilNotEmpty(ctx, query, "select changes() as 'Rows Affected'") 23 | } 24 | 25 | func (d *sqliteDriver) Columns(opts *core.TableOptions) ([]*core.Column, error) { 26 | return d.c.ColumnsFromQuery("SELECT name, type FROM pragma_table_info('%s')", opts.Table) 27 | } 28 | 29 | func (d *sqliteDriver) Structure() ([]*core.Structure, error) { 30 | // sqlite is single schema structure, so we hardcode the name of it. 
31 | query := "SELECT 'sqlite_schema' as schema, name, type FROM sqlite_schema" 32 | 33 | rows, err := d.Query(context.Background(), query) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | decodeStructureType := func(typ string) core.StructureType { 39 | switch typ { 40 | case "table": 41 | return core.StructureTypeTable 42 | case "view": 43 | return core.StructureTypeView 44 | default: 45 | return core.StructureTypeNone 46 | } 47 | } 48 | return core.GetGenericStructure(rows, decodeStructureType) 49 | } 50 | 51 | func (d *sqliteDriver) Close() { d.c.Close() } 52 | 53 | func (d *sqliteDriver) ListDatabases() (string, []string, error) { 54 | return d.currentDatabase, []string{"not supported yet"}, nil 55 | } 56 | 57 | // SelectDatabase is a no-op, added to make the UI more pleasent. 58 | func (d *sqliteDriver) SelectDatabase(name string) error { return nil } 59 | -------------------------------------------------------------------------------- /dbee/adapters/sqlserver_driver.go: -------------------------------------------------------------------------------- 1 | package adapters 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | nurl "net/url" 8 | "time" 9 | 10 | "github.com/kndndrj/nvim-dbee/dbee/core" 11 | "github.com/kndndrj/nvim-dbee/dbee/core/builders" 12 | ) 13 | 14 | var ( 15 | _ core.Driver = (*sqlServerDriver)(nil) 16 | _ core.DatabaseSwitcher = (*sqlServerDriver)(nil) 17 | ) 18 | 19 | type sqlServerDriver struct { 20 | c *builders.Client 21 | url *nurl.URL 22 | } 23 | 24 | func (c *sqlServerDriver) Query(ctx context.Context, query string) (core.ResultStream, error) { 25 | // run query, fallback to affected rows 26 | return c.c.QueryUntilNotEmpty(ctx, query, "select @@ROWCOUNT as 'Rows Affected'") 27 | } 28 | 29 | func (c *sqlServerDriver) Columns(opts *core.TableOptions) ([]*core.Column, error) { 30 | return c.c.ColumnsFromQuery(` 31 | SELECT 32 | column_name, 33 | data_type 34 | FROM information_schema.columns 35 | WHERE 
table_name='%s' AND 36 | table_schema = '%s'`, 37 | opts.Table, 38 | opts.Schema, 39 | ) 40 | } 41 | 42 | func (c *sqlServerDriver) Structure() ([]*core.Structure, error) { 43 | query := ` 44 | SELECT table_schema, table_name, table_type 45 | FROM INFORMATION_SCHEMA.TABLES` 46 | 47 | rows, err := c.Query(context.TODO(), query) 48 | if err != nil { 49 | return nil, err 50 | } 51 | 52 | return core.GetGenericStructure(rows, getPGStructureType) 53 | } 54 | 55 | func (c *sqlServerDriver) Close() { 56 | c.c.Close() 57 | } 58 | 59 | func (c *sqlServerDriver) ListDatabases() (current string, available []string, err error) { 60 | query := ` 61 | SELECT DB_NAME(), name 62 | FROM sys.databases 63 | WHERE name != DB_NAME(); 64 | ` 65 | 66 | rows, err := c.Query(context.TODO(), query) 67 | if err != nil { 68 | return "", nil, err 69 | } 70 | 71 | for rows.HasNext() { 72 | row, err := rows.Next() 73 | if err != nil { 74 | return "", nil, err 75 | } 76 | 77 | // We know for a fact there are 2 string fields (see query above) 78 | current = row[0].(string) 79 | available = append(available, row[1].(string)) 80 | } 81 | 82 | return current, available, nil 83 | } 84 | 85 | func (c *sqlServerDriver) SelectDatabase(name string) error { 86 | q := c.url.Query() 87 | q.Set("database", name) 88 | c.url.RawQuery = q.Encode() 89 | 90 | db, err := sql.Open("sqlserver", c.url.String()) 91 | if err != nil { 92 | return fmt.Errorf("unable to switch databases: %w", err) 93 | } 94 | 95 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 96 | defer cancel() 97 | 98 | if err := db.PingContext(ctx); err != nil { 99 | return fmt.Errorf("unable to switch databases: %w", err) 100 | } 101 | 102 | c.c.Swap(db) 103 | 104 | return nil 105 | } 106 | -------------------------------------------------------------------------------- /dbee/core/builders/client_options.go: -------------------------------------------------------------------------------- 1 | package builders 2 | 3 | import 
"strings" 4 | 5 | type clientConfig struct { 6 | typeProcessors map[string]func(any) any 7 | } 8 | 9 | type ClientOption func(*clientConfig) 10 | 11 | func WithCustomTypeProcessor(typ string, fn func(any) any) ClientOption { 12 | return func(cc *clientConfig) { 13 | t := strings.ToLower(typ) 14 | _, ok := cc.typeProcessors[t] 15 | if ok { 16 | // processor already registered for this type 17 | return 18 | } 19 | 20 | cc.typeProcessors[t] = fn 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /dbee/core/builders/columns.go: -------------------------------------------------------------------------------- 1 | package builders 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/kndndrj/nvim-dbee/dbee/core" 8 | ) 9 | 10 | // ColumnsFromResultStream converts the result stream to columns. 11 | // A result stream should return rows that are at least 2 columns wide and 12 | // have the following structure: 13 | // 14 | // 1st elem: name - string 15 | // 2nd elem: type - string 16 | func ColumnsFromResultStream(rows core.ResultStream) ([]*core.Column, error) { 17 | var out []*core.Column 18 | 19 | for rows.HasNext() { 20 | row, err := rows.Next() 21 | if err != nil { 22 | return nil, fmt.Errorf("result.Next: %w", err) 23 | } 24 | 25 | if len(row) < 2 { 26 | return nil, errors.New("could not retrieve column info: insufficient data") 27 | } 28 | 29 | name, ok := row[0].(string) 30 | if !ok { 31 | return nil, errors.New("could not retrieve column info: name not a string") 32 | } 33 | 34 | typ, ok := row[1].(string) 35 | if !ok { 36 | return nil, errors.New("could not retrieve column info: type not a string") 37 | } 38 | 39 | column := &core.Column{ 40 | Name: name, 41 | Type: typ, 42 | } 43 | 44 | out = append(out, column) 45 | } 46 | 47 | return out, nil 48 | } 49 | -------------------------------------------------------------------------------- /dbee/core/builders/next.go: 
-------------------------------------------------------------------------------- 1 | package builders 2 | 3 | import ( 4 | "errors" 5 | "sync/atomic" 6 | "time" 7 | 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | ) 10 | 11 | // NextSingle creates next and hasNext functions from a provided single value 12 | func NextSingle(value any) (func() (core.Row, error), func() bool) { 13 | has := true 14 | 15 | // iterator functions 16 | next := func() (core.Row, error) { 17 | if !has { 18 | return nil, errors.New("no next row") 19 | } 20 | has = false 21 | return core.Row{value}, nil 22 | } 23 | 24 | hasNext := func() bool { 25 | return has 26 | } 27 | 28 | return next, hasNext 29 | } 30 | 31 | // NextSlice creates next and hasNext functions from provided values 32 | // preprocessor is an optional function which parses a single value from slice before adding it to a row 33 | func NextSlice[T any](values []T, preprocess func(T) any) (func() (core.Row, error), func() bool) { 34 | if preprocess == nil { 35 | preprocess = func(v T) any { return v } 36 | } 37 | 38 | index := 0 39 | 40 | hasNext := func() bool { 41 | return index < len(values) 42 | } 43 | 44 | // iterator functions 45 | next := func() (core.Row, error) { 46 | if !hasNext() { 47 | return nil, errors.New("no next row") 48 | } 49 | 50 | row := core.Row{preprocess(values[index])} 51 | index++ 52 | return row, nil 53 | } 54 | 55 | return next, hasNext 56 | } 57 | 58 | // NextNil creates next and hasNext functions that don't return anything (no rows) 59 | func NextNil() (func() (core.Row, error), func() bool) { 60 | hasNext := func() bool { 61 | return false 62 | } 63 | 64 | // iterator functions 65 | next := func() (core.Row, error) { 66 | return nil, errors.New("no next row") 67 | } 68 | 69 | return next, hasNext 70 | } 71 | 72 | // closeOnce closes the channel if it isn't already closed. 
73 | func closeOnce[T any](ch chan T) { 74 | select { 75 | case <-ch: 76 | default: 77 | close(ch) 78 | } 79 | } 80 | 81 | // NextYield creates next and hasNext functions by calling yield in internal function. 82 | // WARNING: the caller must call "hasNext" before each call to "next". 83 | func NextYield(fn func(yield func(...any)) error) (func() (core.Row, error), func() bool) { 84 | resultsCh := make(chan []any, 10) 85 | errorsCh := make(chan error, 1) 86 | readyCh := make(chan struct{}) 87 | doneCh := make(chan struct{}) 88 | 89 | // spawn channel function 90 | go func() { 91 | defer func() { 92 | close(doneCh) 93 | closeOnce(readyCh) 94 | close(resultsCh) 95 | close(errorsCh) 96 | }() 97 | 98 | err := fn(func(v ...any) { 99 | resultsCh <- v 100 | closeOnce(readyCh) 101 | }) 102 | if err != nil { 103 | errorsCh <- err 104 | } 105 | }() 106 | 107 | <-readyCh 108 | 109 | var nextVal atomic.Value 110 | var nextErr atomic.Value 111 | 112 | var hasNext func() bool 113 | hasNext = func() bool { 114 | select { 115 | case vals, ok := <-resultsCh: 116 | if !ok { 117 | return false 118 | } 119 | nextVal.Store(vals) 120 | return true 121 | case err := <-errorsCh: 122 | if err != nil { 123 | nextErr.Store(err) 124 | return false 125 | } 126 | case <-doneCh: 127 | if len(resultsCh) < 1 { 128 | return false 129 | } 130 | case <-time.After(5 * time.Second): 131 | nextErr.Store(errors.New("next row timeout")) 132 | return false 133 | } 134 | 135 | return hasNext() 136 | } 137 | 138 | next := func() (core.Row, error) { 139 | var val core.Row 140 | var err error 141 | 142 | nval := nextVal.Load() 143 | if nval != nil { 144 | val = nval.([]any) 145 | } 146 | nerr := nextErr.Load() 147 | if nerr != nil { 148 | err = nerr.(error) 149 | } 150 | return val, err 151 | } 152 | 153 | return next, hasNext 154 | } 155 | -------------------------------------------------------------------------------- /dbee/core/builders/next_test.go: 
-------------------------------------------------------------------------------- 1 | package builders_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "time" 7 | 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | "github.com/kndndrj/nvim-dbee/dbee/core/builders" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func testNextYield(t *testing.T, sleep bool) { 14 | r := require.New(t) 15 | 16 | rows := [][]any{{"first", "row"}, {"second"}, {"third"}, {"fourth"}, {"fifth"}, {"and", "last", "row"}} 17 | 18 | next, hasNext := builders.NextYield(func(yield func(...any)) error { 19 | for i, row := range rows { 20 | if sleep && (i == 2 || i == 4) { 21 | time.Sleep(500 * time.Millisecond) 22 | } 23 | yield(row...) 24 | } 25 | 26 | return nil 27 | }) 28 | 29 | i := 0 30 | for hasNext() { 31 | row, err := next() 32 | 33 | r.NoError(err) 34 | 35 | r.NotEqual(0, len(row)) 36 | 37 | r.Equal(row, core.Row(rows[i])) 38 | 39 | i++ 40 | } 41 | 42 | r.Equal(i, len(rows)) 43 | } 44 | 45 | func TestNextYield_Success(t *testing.T) { 46 | // test with random sleeping 47 | testNextYield(t, true) 48 | 49 | for i := 0; i < 1000; i++ { 50 | testNextYield(t, false) 51 | } 52 | } 53 | 54 | func TestNextYield_Error(t *testing.T) { 55 | expectedError := errors.New("expected error") 56 | 57 | next, hasNext := builders.NextYield(func(yield func(...any)) error { 58 | return expectedError 59 | }) 60 | 61 | for hasNext() { 62 | _, err := next() 63 | require.Error(t, err, expectedError.Error()) 64 | } 65 | } 66 | 67 | func TestNextYield_NoRows(t *testing.T) { 68 | _, hasNext := builders.NextYield(func(yield func(...any)) error { 69 | time.Sleep(1 * time.Second) 70 | return nil 71 | }) 72 | 73 | require.Equal(t, false, hasNext()) 74 | } 75 | 76 | func TestNextYield_SingleRow(t *testing.T) { 77 | r := require.New(t) 78 | next, hasNext := builders.NextYield(func(yield func(...any)) error { 79 | yield(1) 80 | time.Sleep(1 * time.Second) 81 | return nil 82 | }) 83 | 84 | r.True(hasNext()) 85 | 
86 | row, err := next() 87 | r.NoError(err) 88 | r.Equal(1, len(row)) 89 | r.Equal(1, row[0]) 90 | 91 | r.Equal(false, hasNext()) 92 | } 93 | -------------------------------------------------------------------------------- /dbee/core/builders/result.go: -------------------------------------------------------------------------------- 1 | package builders 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | 7 | "github.com/kndndrj/nvim-dbee/dbee/core" 8 | ) 9 | 10 | var _ core.ResultStream = (*ResultStream)(nil) 11 | 12 | type ResultStream struct { 13 | next func() (core.Row, error) 14 | hasNext func() bool 15 | closes []func() 16 | meta *core.Meta 17 | header core.Header 18 | once sync.Once 19 | } 20 | 21 | func (r *ResultStream) AddCallback(fn func()) { 22 | r.closes = append(r.closes, fn) 23 | } 24 | 25 | func (r *ResultStream) Meta() *core.Meta { 26 | return r.meta 27 | } 28 | 29 | func (r *ResultStream) Header() core.Header { 30 | return r.header 31 | } 32 | 33 | func (r *ResultStream) HasNext() bool { 34 | return r.hasNext() 35 | } 36 | 37 | func (r *ResultStream) Next() (core.Row, error) { 38 | rows, err := r.next() 39 | if err != nil || rows == nil { 40 | r.Close() 41 | return nil, err 42 | } 43 | return rows, nil 44 | } 45 | 46 | func (r *ResultStream) Close() { 47 | r.once.Do(func() { 48 | for _, fn := range r.closes { 49 | if fn != nil { 50 | fn() 51 | } 52 | } 53 | }) 54 | 55 | r.hasNext = func() bool { 56 | return false 57 | } 58 | } 59 | 60 | // ResultStreamBuilder builds the rows 61 | type ResultStreamBuilder struct { 62 | next func() (core.Row, error) 63 | hasNext func() bool 64 | header core.Header 65 | closes []func() 66 | meta *core.Meta 67 | } 68 | 69 | func NewResultStreamBuilder() *ResultStreamBuilder { 70 | return &ResultStreamBuilder{ 71 | next: func() (core.Row, error) { return nil, errors.New("no next row") }, 72 | hasNext: func() bool { return false }, 73 | header: core.Header{}, 74 | meta: &core.Meta{}, 75 | } 76 | } 77 | 78 | func (b 
*ResultStreamBuilder) WithNextFunc(fn func() (core.Row, error), has func() bool) *ResultStreamBuilder { 79 | b.next = fn 80 | b.hasNext = has 81 | return b 82 | } 83 | 84 | func (b *ResultStreamBuilder) WithHeader(header core.Header) *ResultStreamBuilder { 85 | b.header = header 86 | return b 87 | } 88 | 89 | func (b *ResultStreamBuilder) WithCloseFunc(fn func()) *ResultStreamBuilder { 90 | b.closes = append(b.closes, fn) 91 | return b 92 | } 93 | 94 | func (b *ResultStreamBuilder) WithMeta(meta *core.Meta) *ResultStreamBuilder { 95 | b.meta = meta 96 | return b 97 | } 98 | 99 | func (b *ResultStreamBuilder) Build() *ResultStream { 100 | return &ResultStream{ 101 | next: b.next, 102 | hasNext: b.hasNext, 103 | header: b.header, 104 | closes: b.closes, 105 | meta: b.meta, 106 | once: sync.Once{}, 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /dbee/core/call_state.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | type CallState int 4 | 5 | const ( 6 | CallStateUnknown CallState = iota 7 | CallStateExecuting 8 | CallStateExecutingFailed 9 | CallStateRetrieving 10 | CallStateRetrievingFailed 11 | CallStateArchived 12 | CallStateArchiveFailed 13 | CallStateCanceled 14 | ) 15 | 16 | func CallStateFromString(s string) CallState { 17 | switch s { 18 | case CallStateUnknown.String(): 19 | return CallStateUnknown 20 | 21 | case CallStateExecuting.String(): 22 | return CallStateExecuting 23 | case CallStateExecutingFailed.String(): 24 | return CallStateExecutingFailed 25 | 26 | case CallStateRetrieving.String(): 27 | return CallStateRetrieving 28 | case CallStateRetrievingFailed.String(): 29 | return CallStateRetrievingFailed 30 | 31 | case CallStateArchived.String(): 32 | return CallStateArchived 33 | case CallStateArchiveFailed.String(): 34 | return CallStateArchiveFailed 35 | 36 | case CallStateCanceled.String(): 37 | return CallStateCanceled 38 | 39 | 
default: 40 | return CallStateUnknown 41 | } 42 | } 43 | 44 | func (s CallState) String() string { 45 | switch s { 46 | case CallStateUnknown: 47 | return "unknown" 48 | 49 | case CallStateExecuting: 50 | return "executing" 51 | case CallStateExecutingFailed: 52 | return "executing_failed" 53 | 54 | case CallStateRetrieving: 55 | return "retrieving" 56 | case CallStateRetrievingFailed: 57 | return "retrieving_failed" 58 | 59 | case CallStateArchived: 60 | return "archived" 61 | case CallStateArchiveFailed: 62 | return "archive_failed" 63 | 64 | case CallStateCanceled: 65 | return "canceled" 66 | 67 | default: 68 | return "unknown" 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /dbee/core/connection_params.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import "encoding/json" 4 | 5 | type ConnectionParams struct { 6 | ID ConnectionID 7 | Name string 8 | Type string 9 | URL string 10 | } 11 | 12 | // Expand returns a copy of the original parameters with expanded fields 13 | func (p *ConnectionParams) Expand() *ConnectionParams { 14 | return &ConnectionParams{ 15 | ID: ConnectionID(expandOrDefault(string(p.ID))), 16 | Name: expandOrDefault(p.Name), 17 | Type: expandOrDefault(p.Type), 18 | URL: expandOrDefault(p.URL), 19 | } 20 | } 21 | 22 | func (cp *ConnectionParams) MarshalJSON() ([]byte, error) { 23 | return json.Marshal(struct { 24 | ID string `json:"id"` 25 | Name string `json:"name"` 26 | Type string `json:"type"` 27 | URL string `json:"url"` 28 | }{ 29 | ID: string(cp.ID), 30 | Name: cp.Name, 31 | Type: cp.Type, 32 | URL: cp.URL, 33 | }) 34 | } 35 | -------------------------------------------------------------------------------- /dbee/core/expand.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "os" 7 | "os/exec" 8 | "strings" 9 | "text/template" 10 | 
) 11 | 12 | func expand(value string) (string, error) { 13 | tmpl, err := template.New("expand_variables"). 14 | Funcs(template.FuncMap{ 15 | "env": os.Getenv, 16 | "exec": execCommand, 17 | }). 18 | Parse(value) 19 | if err != nil { 20 | return "", err 21 | } 22 | 23 | var out bytes.Buffer 24 | err = tmpl.Execute(&out, nil) 25 | if err != nil { 26 | return "", err 27 | } 28 | 29 | return out.String(), nil 30 | } 31 | 32 | func execCommand(line string) (string, error) { 33 | if strings.Contains(line, " | ") { 34 | out, err := exec.Command("sh", "-c", line).Output() 35 | return strings.TrimSpace(string(out)), err 36 | } 37 | 38 | l := strings.Split(line, " ") 39 | if len(l) < 1 { 40 | return "", errors.New("no command provided") 41 | } 42 | cmd := l[0] 43 | args := l[1:] 44 | 45 | out, err := exec.Command(cmd, args...).Output() 46 | return strings.TrimSpace(string(out)), err 47 | } 48 | 49 | // expandOrDefault silently suppresses errors. 50 | func expandOrDefault(value string) string { 51 | ex, err := expand(value) 52 | if err != nil { 53 | return value 54 | } 55 | return ex 56 | } 57 | -------------------------------------------------------------------------------- /dbee/core/expand_test.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestExpand(t *testing.T) { 11 | r := require.New(t) 12 | 13 | testCases := []struct { 14 | input string 15 | expected string 16 | }{ 17 | {"normal string", "normal string"}, 18 | {"{{ env `HOME` }}", os.Getenv("HOME")}, 19 | {"{{ exec `echo \"hello\nbuddy\" | grep buddy` }}", "buddy"}, 20 | } 21 | 22 | for _, tc := range testCases { 23 | actual, err := expand(tc.input) 24 | r.NoError(err) 25 | 26 | r.Equal(tc.expected, actual) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /dbee/core/format/csv.go: 
-------------------------------------------------------------------------------- 1 | package format 2 | 3 | import ( 4 | "bytes" 5 | "encoding/csv" 6 | "fmt" 7 | 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | ) 10 | 11 | var _ core.Formatter = (*CSV)(nil) 12 | 13 | type CSV struct{} 14 | 15 | func NewCSV() *CSV { 16 | return &CSV{} 17 | } 18 | 19 | func (cf *CSV) parseSchemaFul(header core.Header, rows []core.Row) [][]string { 20 | data := [][]string{ 21 | header, 22 | } 23 | for _, row := range rows { 24 | var csvRow []string 25 | for _, rec := range row { 26 | csvRow = append(csvRow, fmt.Sprint(rec)) 27 | } 28 | data = append(data, csvRow) 29 | } 30 | 31 | return data 32 | } 33 | 34 | func (cf *CSV) Format(header core.Header, rows []core.Row, _ *core.FormatterOptions) ([]byte, error) { 35 | // parse as if schema is defined regardles of schema presence in the result 36 | data := cf.parseSchemaFul(header, rows) 37 | 38 | b := new(bytes.Buffer) 39 | w := csv.NewWriter(b) 40 | 41 | err := w.WriteAll(data) 42 | if err != nil { 43 | return nil, fmt.Errorf("w.WriteAll: %w", err) 44 | } 45 | 46 | return b.Bytes(), nil 47 | } 48 | -------------------------------------------------------------------------------- /dbee/core/format/json.go: -------------------------------------------------------------------------------- 1 | package format 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/kndndrj/nvim-dbee/dbee/core" 8 | ) 9 | 10 | var _ core.Formatter = (*JSON)(nil) 11 | 12 | type JSON struct{} 13 | 14 | func NewJSON() *JSON { 15 | return &JSON{} 16 | } 17 | 18 | func (jf *JSON) parseSchemaFul(header core.Header, rows []core.Row) []map[string]any { 19 | var data []map[string]any 20 | 21 | for _, row := range rows { 22 | record := make(map[string]any, len(row)) 23 | for i, val := range row { 24 | var h string 25 | if i < len(header) { 26 | h = header[i] 27 | } else { 28 | h = fmt.Sprintf("", i) 29 | } 30 | record[h] = val 31 | } 32 | data = append(data, 
record) 33 | } 34 | 35 | return data 36 | } 37 | 38 | func (jf *JSON) parseSchemaLess(header core.Header, rows []core.Row) []any { 39 | var data []any 40 | 41 | for _, row := range rows { 42 | if len(row) == 1 { 43 | data = append(data, row[0]) 44 | } else if len(row) > 1 { 45 | data = append(data, row) 46 | } 47 | } 48 | return data 49 | } 50 | 51 | func (jf *JSON) Format(header core.Header, rows []core.Row, opts *core.FormatterOptions) ([]byte, error) { 52 | var data any 53 | switch opts.SchemaType { 54 | case core.SchemaLess: 55 | data = jf.parseSchemaLess(header, rows) 56 | case core.SchemaFul: 57 | fallthrough 58 | default: 59 | data = jf.parseSchemaFul(header, rows) 60 | } 61 | 62 | out, err := json.MarshalIndent(data, "", " ") 63 | if err != nil { 64 | return nil, fmt.Errorf("json.MarshalIndent: %w", err) 65 | } 66 | 67 | return out, nil 68 | } 69 | -------------------------------------------------------------------------------- /dbee/core/mock/adapter.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/kndndrj/nvim-dbee/dbee/core" 8 | ) 9 | 10 | var _ core.Driver = (*driver)(nil) 11 | 12 | type driver struct { 13 | data []core.Row 14 | config *adapterConfig 15 | } 16 | 17 | func (d *driver) Query(ctx context.Context, query string) (core.ResultStream, error) { 18 | eff, ok := d.config.querySideEffects[query] 19 | if ok { 20 | err := eff(ctx) 21 | if err != nil { 22 | return nil, fmt.Errorf("side effect error: %w", err) 23 | } 24 | } 25 | 26 | return NewResultStream(d.data, d.config.resultStreamOptions...), nil 27 | } 28 | 29 | func (d *driver) Structure() ([]*core.Structure, error) { 30 | var structure []*core.Structure 31 | 32 | for table := range d.config.tableColumns { 33 | structure = append(structure, &core.Structure{ 34 | Name: table, 35 | Type: core.StructureTypeTable, 36 | }) 37 | } 38 | 39 | return structure, nil 40 | } 41 | 42 | func (d 
*driver) Columns(opts *core.TableOptions) ([]*core.Column, error) { 43 | columns, ok := d.config.tableColumns[opts.Table] 44 | if !ok { 45 | return nil, fmt.Errorf("unknown table: %s", opts.Table) 46 | } 47 | 48 | return columns, nil 49 | } 50 | 51 | func (d *driver) Close() {} 52 | 53 | var _ core.Adapter = (*Adapter)(nil) 54 | 55 | type Adapter struct { 56 | data []core.Row 57 | config *adapterConfig 58 | } 59 | 60 | func NewAdapter(data []core.Row, opts ...AdapterOption) *Adapter { 61 | config := &adapterConfig{ 62 | querySideEffects: make(map[string]func(context.Context) error), 63 | tableHelpers: make(map[string]string), 64 | tableColumns: make(map[string][]*core.Column), 65 | 66 | resultStreamOptions: []ResultStreamOption{}, 67 | } 68 | for _, opt := range opts { 69 | opt(config) 70 | } 71 | 72 | return &Adapter{ 73 | data: data, 74 | config: config, 75 | } 76 | } 77 | 78 | func (a *Adapter) Connect(_ string) (core.Driver, error) { 79 | return &driver{ 80 | data: a.data, 81 | config: a.config, 82 | }, nil 83 | } 84 | 85 | func (a *Adapter) GetHelpers(opts *core.TableOptions) map[string]string { 86 | return a.config.tableHelpers 87 | } 88 | -------------------------------------------------------------------------------- /dbee/core/mock/adapter_options.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/kndndrj/nvim-dbee/dbee/core" 7 | ) 8 | 9 | type adapterConfig struct { 10 | querySideEffects map[string]func(context.Context) error 11 | tableHelpers map[string]string 12 | tableColumns map[string][]*core.Column 13 | 14 | resultStreamOptions []ResultStreamOption 15 | } 16 | 17 | type AdapterOption func(*adapterConfig) 18 | 19 | func AdapterWithQuerySideEffect(query string, sideEffect func(context.Context) error) AdapterOption { 20 | return func(c *adapterConfig) { 21 | _, ok := c.querySideEffects[query] 22 | if ok { 23 | panic("side effect already registered for 
query: " + query) 24 | } 25 | 26 | c.querySideEffects[query] = sideEffect 27 | } 28 | } 29 | 30 | func AdapterWithTableHelper(name string, query string) AdapterOption { 31 | return func(c *adapterConfig) { 32 | _, ok := c.tableHelpers[name] 33 | if ok { 34 | panic("query already registered for table helper: " + name) 35 | } 36 | 37 | c.tableHelpers[name] = query 38 | } 39 | } 40 | 41 | func AdapterWithTableDefinition(table string, columns []*core.Column) AdapterOption { 42 | return func(c *adapterConfig) { 43 | _, ok := c.tableColumns[table] 44 | if ok { 45 | panic("columns already registered for table: " + table) 46 | } 47 | 48 | c.tableColumns[table] = columns 49 | } 50 | } 51 | 52 | func AdapterWithResultStreamOpts(opts ...ResultStreamOption) AdapterOption { 53 | return func(c *adapterConfig) { 54 | c.resultStreamOptions = append(c.resultStreamOptions, opts...) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /dbee/core/mock/result.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | ) 10 | 11 | func newNext(rows []core.Row) (func() (core.Row, error), func() bool) { 12 | index := 0 13 | 14 | hasNext := func() bool { 15 | return index < len(rows) 16 | } 17 | 18 | // iterator functions 19 | next := func() (core.Row, error) { 20 | if !hasNext() { 21 | return nil, errors.New("no next row") 22 | } 23 | 24 | row := rows[index] 25 | index++ 26 | return row, nil 27 | } 28 | 29 | return next, hasNext 30 | } 31 | 32 | type ResultStream struct { 33 | next func() (core.Row, error) 34 | hasNext func() bool 35 | config *resultStreamConfig 36 | } 37 | 38 | func makeDefaultHeader(rows []core.Row) core.Header { 39 | var header core.Header 40 | if len(rows) > 0 { 41 | for i := range rows[0] { 42 | header = append(header, fmt.Sprintf("header_%d", i)) 43 | } 44 | } 45 | return 
header 46 | } 47 | 48 | // NewResultStream returns a mocked result stream with provided rows. 49 | // It creates a header that matches the number of columns in the first row 50 | // in form of: , , etc. 51 | func NewResultStream(rows []core.Row, opts ...ResultStreamOption) *ResultStream { 52 | config := &resultStreamConfig{ 53 | nextSleep: 0, 54 | meta: &core.Meta{}, 55 | header: makeDefaultHeader(rows), 56 | } 57 | for _, opt := range opts { 58 | opt(config) 59 | } 60 | 61 | next, hasNext := newNext(rows) 62 | 63 | return &ResultStream{ 64 | next: next, 65 | hasNext: hasNext, 66 | config: config, 67 | } 68 | } 69 | 70 | func (rs *ResultStream) Meta() *core.Meta { 71 | return rs.config.meta 72 | } 73 | 74 | func (rs *ResultStream) Header() core.Header { 75 | return rs.config.header 76 | } 77 | 78 | func (rs *ResultStream) Next() (core.Row, error) { 79 | time.Sleep(rs.config.nextSleep) 80 | return rs.next() 81 | } 82 | 83 | func (rs *ResultStream) HasNext() bool { 84 | return rs.hasNext() 85 | } 86 | 87 | func (rs *ResultStream) Close() {} 88 | 89 | // NewRows returns a slice of rows in form of: 90 | // 91 | // { (int), "row_"(string) } 92 | // 93 | // where the first index is "from" and the last one is one less than "to". 
94 | func NewRows(from, to int) []core.Row { 95 | var rows []core.Row 96 | 97 | for i := from; i < to; i++ { 98 | rows = append(rows, core.Row{i, fmt.Sprintf("row_%d", i)}) 99 | } 100 | return rows 101 | } 102 | -------------------------------------------------------------------------------- /dbee/core/mock/result_options.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/kndndrj/nvim-dbee/dbee/core" 7 | ) 8 | 9 | type resultStreamConfig struct { 10 | nextSleep time.Duration 11 | meta *core.Meta 12 | header core.Header 13 | } 14 | 15 | type ResultStreamOption func(*resultStreamConfig) 16 | 17 | func ResultStreamWithNextSleep(s time.Duration) ResultStreamOption { 18 | return func(c *resultStreamConfig) { 19 | c.nextSleep = s 20 | } 21 | } 22 | 23 | func ResultStreamWithMeta(meta *core.Meta) ResultStreamOption { 24 | return func(c *resultStreamConfig) { 25 | c.meta = meta 26 | } 27 | } 28 | 29 | func ResultStreamWithHeader(header core.Header) ResultStreamOption { 30 | return func(c *resultStreamConfig) { 31 | c.header = header 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /dbee/core/result.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | var ErrInvalidRange = func(from, to int) error { return fmt.Errorf("invalid selection range: %d ... %d", from, to) } 11 | 12 | // Result is the cached form of the ResultStream iterator 13 | type Result struct { 14 | header Header 15 | meta *Meta 16 | rows []Row 17 | 18 | isDrained bool 19 | isFilled bool 20 | writeMutex sync.Mutex 21 | readMutex sync.RWMutex 22 | } 23 | 24 | // SetIter sets the ResultStream iterator to result. 25 | // This can be done only once! 
26 | func (cr *Result) SetIter(iter ResultStream, onFillStart func()) error { 27 | // lock write mutex 28 | cr.writeMutex.Lock() 29 | defer cr.writeMutex.Unlock() 30 | 31 | // close iterator on return 32 | defer iter.Close() 33 | 34 | cr.header = iter.Header() 35 | cr.meta = iter.Meta() 36 | cr.rows = make([]Row, 0) 37 | 38 | cr.isDrained = false 39 | cr.isFilled = true 40 | 41 | defer func() { cr.isDrained = true }() 42 | 43 | // trigger callback 44 | if onFillStart != nil { 45 | onFillStart() 46 | } 47 | 48 | // drain the iterator 49 | for iter.HasNext() { 50 | row, err := iter.Next() 51 | if err != nil { 52 | cr.isFilled = false 53 | return err 54 | } 55 | 56 | cr.rows = append(cr.rows, row) 57 | } 58 | 59 | return nil 60 | } 61 | 62 | func (cr *Result) Wipe() { 63 | // lock write and read mutexes 64 | cr.writeMutex.Lock() 65 | defer cr.writeMutex.Unlock() 66 | cr.readMutex.Lock() 67 | defer cr.readMutex.Unlock() 68 | 69 | // clear everything 70 | cr.header = Header{} 71 | cr.meta = &Meta{} 72 | cr.rows = []Row{} 73 | cr.isDrained = false 74 | cr.isFilled = false 75 | } 76 | 77 | func (cr *Result) Format(formatter Formatter, from, to int) ([]byte, error) { 78 | rows, fromAdjusted, _, err := cr.getRows(from, to) 79 | if err != nil { 80 | return nil, fmt.Errorf("cr.Rows: %w", err) 81 | } 82 | 83 | opts := &FormatterOptions{ 84 | SchemaType: cr.meta.SchemaType, 85 | ChunkStart: fromAdjusted, 86 | } 87 | 88 | f, err := formatter.Format(cr.header, rows, opts) 89 | if err != nil { 90 | return nil, fmt.Errorf("formatter.Format: %w", err) 91 | } 92 | 93 | return f, nil 94 | } 95 | 96 | func (cr *Result) Len() int { 97 | return len(cr.rows) 98 | } 99 | 100 | func (cr *Result) IsEmpty() bool { 101 | return !cr.isFilled 102 | } 103 | 104 | func (cr *Result) Header() Header { 105 | return cr.header 106 | } 107 | 108 | func (cr *Result) Meta() *Meta { 109 | return cr.meta 110 | } 111 | 112 | func (cr *Result) Rows(from, to int) ([]Row, error) { 113 | rows, _, _, err := 
cr.getRows(from, to) 114 | return rows, err 115 | } 116 | 117 | // getRows returns the row range and adjusted from-to values 118 | func (cr *Result) getRows(from, to int) (rows []Row, rangeFrom, rangeTo int, err error) { 119 | // increment the read mutex 120 | cr.readMutex.RLock() 121 | defer cr.readMutex.RUnlock() 122 | 123 | // validation 124 | if (from < 0 && to < 0) || (from >= 0 && to >= 0) { 125 | if from > to { 126 | return nil, 0, 0, ErrInvalidRange(from, to) 127 | } 128 | } 129 | // undefined -> error 130 | if from < 0 && to >= 0 { 131 | return nil, 0, 0, ErrInvalidRange(from, to) 132 | } 133 | 134 | // timeout context 135 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) 136 | defer cancel() 137 | 138 | // Wait for drain, available index or timeout 139 | for !cr.isDrained && (to < 0 || to > len(cr.rows)) { 140 | 141 | if err := ctx.Err(); err != nil { 142 | return nil, 0, 0, fmt.Errorf("cache flushing timeout exceeded: %s", err) 143 | } 144 | time.Sleep(50 * time.Millisecond) 145 | } 146 | 147 | // calculate range 148 | length := len(cr.rows) 149 | if from < 0 { 150 | from += length + 1 151 | if from < 0 { 152 | from = 0 153 | } 154 | } 155 | if to < 0 { 156 | to += length + 1 157 | if to < 0 { 158 | to = 0 159 | } 160 | } 161 | 162 | if from > length { 163 | from = length 164 | } 165 | if to > length { 166 | to = length 167 | } 168 | 169 | return cr.rows[from:to], from, to, nil 170 | } 171 | -------------------------------------------------------------------------------- /dbee/core/result_test.go: -------------------------------------------------------------------------------- 1 | package core_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/kndndrj/nvim-dbee/dbee/core" 10 | "github.com/kndndrj/nvim-dbee/dbee/core/mock" 11 | ) 12 | 13 | func TestResult(t *testing.T) { 14 | type testCase struct { 15 | name string 16 | from int 17 | to int 18 | input []core.Row 19 | 
expected []core.Row 20 | expectedError error 21 | } 22 | 23 | testCases := []testCase{ 24 | { 25 | name: "get all", 26 | from: 0, 27 | to: -1, 28 | input: mock.NewRows(0, 10), 29 | expected: mock.NewRows(0, 10), 30 | expectedError: nil, 31 | }, 32 | { 33 | name: "get basic range", 34 | from: 0, 35 | to: 3, 36 | input: mock.NewRows(0, 10), 37 | expected: mock.NewRows(0, 3), 38 | expectedError: nil, 39 | }, 40 | { 41 | name: "get last 2", 42 | from: -3, 43 | to: -1, 44 | input: mock.NewRows(0, 10), 45 | expected: mock.NewRows(8, 10), 46 | expectedError: nil, 47 | }, 48 | { 49 | name: "get only one", 50 | from: 0, 51 | to: 1, 52 | input: mock.NewRows(0, 10), 53 | expected: mock.NewRows(0, 1), 54 | expectedError: nil, 55 | }, 56 | 57 | { 58 | name: "invalid range", 59 | from: 5, 60 | to: 1, 61 | input: mock.NewRows(0, 10), 62 | expected: nil, 63 | expectedError: core.ErrInvalidRange(5, 1), 64 | }, 65 | { 66 | name: "invalid range (even if 10 can be higher than -1, its undefined and should fail)", 67 | from: -5, 68 | to: 10, 69 | input: mock.NewRows(0, 10), 70 | expected: nil, 71 | expectedError: core.ErrInvalidRange(-5, 10), 72 | }, 73 | 74 | { 75 | name: "wait for available index", 76 | from: 0, 77 | to: 3, 78 | input: mock.NewRows(0, 10), 79 | expected: mock.NewRows(0, 3), 80 | expectedError: nil, 81 | }, 82 | { 83 | name: "wait for all to be drained", 84 | from: 0, 85 | to: -1, 86 | input: mock.NewRows(0, 10), 87 | expected: mock.NewRows(0, 10), 88 | expectedError: nil, 89 | }, 90 | } 91 | 92 | result := new(core.Result) 93 | 94 | for _, tc := range testCases { 95 | t.Run(tc.name, func(t *testing.T) { 96 | r := require.New(t) 97 | // wipe any previous result 98 | result.Wipe() 99 | 100 | // set a new iterator with input 101 | err := result.SetIter(mock.NewResultStream(tc.input, mock.ResultStreamWithNextSleep(300*time.Millisecond)), nil) 102 | r.NoError(err) 103 | 104 | rows, err := result.Rows(tc.from, tc.to) 105 | if err != nil { 106 | 
r.ErrorContains(tc.expectedError, err.Error()) 107 | } 108 | r.Equal(rows, tc.expected) 109 | }) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /dbee/core/types.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | ) 7 | 8 | type SchemaType int 9 | 10 | const ( 11 | SchemaFul SchemaType = iota 12 | SchemaLess 13 | ) 14 | 15 | type ( 16 | // FormatterOptions provide various options for formatters 17 | FormatterOptions struct { 18 | SchemaType SchemaType 19 | ChunkStart int 20 | } 21 | 22 | // Formatter converts header and rows to bytes 23 | Formatter interface { 24 | Format(header Header, rows []Row, opts *FormatterOptions) ([]byte, error) 25 | } 26 | ) 27 | 28 | type ( 29 | // Row and Header are attributes of IterResult iterator 30 | Row []any 31 | Header []string 32 | 33 | // Meta holds metadata 34 | Meta struct { 35 | // type of schema (schemaful or schemaless) 36 | SchemaType SchemaType 37 | } 38 | 39 | // ResultStream is a result from executed query and has a form of an iterator 40 | ResultStream interface { 41 | Meta() *Meta 42 | Header() Header 43 | Next() (Row, error) 44 | HasNext() bool 45 | Close() 46 | } 47 | ) 48 | 49 | type StructureType int 50 | 51 | const ( 52 | StructureTypeNone StructureType = iota 53 | StructureTypeTable 54 | StructureTypeView 55 | StructureTypeMaterializedView 56 | StructureTypeStreamingTable 57 | StructureTypeSink 58 | StructureTypeSource 59 | StructureTypeManaged 60 | StructureTypeSchema 61 | ) 62 | 63 | // String returns the string representation of the StructureType 64 | func (s StructureType) String() string { 65 | switch s { 66 | case StructureTypeNone: 67 | return "" 68 | case StructureTypeTable: 69 | return "table" 70 | case StructureTypeView: 71 | return "view" 72 | case StructureTypeMaterializedView: 73 | return "materialized_view" 74 | case StructureTypeStreamingTable: 
75 | return "streaming_table" 76 | case StructureTypeSink: 77 | return "sink" 78 | case StructureTypeSource: 79 | return "source" 80 | case StructureTypeManaged: 81 | return "managed" 82 | case StructureTypeSchema: 83 | return "schema" 84 | default: 85 | return "" 86 | } 87 | } 88 | 89 | // ErrInsufficienStructureInfo is returned when the structure info is insufficient 90 | var ErrInsufficienStructureInfo = errors.New("structure info is insufficient. Expected at least 'schema', 'table' and 'type' columns in that order") 91 | 92 | // GetGenericStructure returns a generic structure for an adapter. 93 | // The rows `ResultStream` need to be a query which returns at least 3 string columns: 94 | // 1. schema 95 | // 2. table 96 | // 3. type 97 | // 98 | // in this order. 99 | // 100 | // The `structTypeFn` function is used to determine the `StructureType` based on the type string. 101 | // `structTypeFn` is adapter specific based on `type` pattern. 102 | // The function should return `StructureTypeNone` if the type is unknown. 
103 | func GetGenericStructure(rows ResultStream, structTypeFn func(string) StructureType) ([]*Structure, error) { 104 | children := make(map[string][]*Structure) 105 | 106 | for rows.HasNext() { 107 | row, err := rows.Next() 108 | if err != nil { 109 | return nil, err 110 | } 111 | if len(row) < 3 { 112 | return nil, ErrInsufficienStructureInfo 113 | } 114 | 115 | errCast := errors.New("expected string, got %T") 116 | schema, ok := row[0].(string) 117 | if !ok { 118 | return nil, errCast 119 | } 120 | table, ok := row[1].(string) 121 | if !ok { 122 | return nil, errCast 123 | } 124 | typ, ok := row[2].(string) 125 | if !ok { 126 | return nil, errCast 127 | } 128 | 129 | children[schema] = append(children[schema], &Structure{ 130 | Name: table, 131 | Schema: schema, 132 | Type: structTypeFn(typ), 133 | }) 134 | } 135 | 136 | structure := make([]*Structure, 0, len(children)) 137 | 138 | for schema, models := range children { 139 | structure = append(structure, &Structure{ 140 | Name: schema, 141 | Schema: schema, 142 | Type: StructureTypeSchema, 143 | Children: models, 144 | }) 145 | } 146 | 147 | return structure, nil 148 | } 149 | 150 | func StructureTypeFromString(s string) StructureType { 151 | switch strings.ToLower(s) { 152 | case "table": 153 | return StructureTypeTable 154 | case "view": 155 | return StructureTypeView 156 | default: 157 | return StructureTypeNone 158 | } 159 | } 160 | 161 | // Structure represents the structure of a single database 162 | type Structure struct { 163 | // Name to be displayed 164 | Name string 165 | Schema string 166 | // Type of layout 167 | Type StructureType 168 | // Children layout nodes 169 | Children []*Structure 170 | } 171 | 172 | type Column struct { 173 | // Column name 174 | Name string 175 | // Database data type 176 | Type string 177 | } 178 | -------------------------------------------------------------------------------- /dbee/handler/call_log.go: 
-------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | ) 10 | 11 | func (h *Handler) storeCallLog() error { 12 | store := make(map[core.ConnectionID][]*core.Call) 13 | 14 | for connID := range h.lookupConnection { 15 | calls, err := h.ConnectionGetCalls(connID) 16 | if err != nil || len(calls) < 1 { 17 | continue 18 | } 19 | store[connID] = calls 20 | } 21 | 22 | b, err := json.MarshalIndent(store, "", " ") 23 | if err != nil { 24 | return fmt.Errorf("json.MarshalIndent: %w", err) 25 | } 26 | 27 | file, err := os.Create(callLogFileName) 28 | if err != nil { 29 | return fmt.Errorf("os.Create: %s", err) 30 | } 31 | defer file.Close() 32 | 33 | _, err = file.Write(b) 34 | if err != nil { 35 | return fmt.Errorf("file.Write: %w", err) 36 | } 37 | 38 | return nil 39 | } 40 | 41 | func (h *Handler) restoreCallLog() error { 42 | file, err := os.Open(callLogFileName) 43 | if err != nil { 44 | return fmt.Errorf("os.Open: %w", err) 45 | } 46 | defer file.Close() 47 | 48 | decoder := json.NewDecoder(file) 49 | 50 | var store map[core.ConnectionID][]*core.Call 51 | 52 | err = decoder.Decode(&store) 53 | if err != nil { 54 | return fmt.Errorf("decoder.Decode: %w", err) 55 | } 56 | 57 | for connID, calls := range store { 58 | callIDs := make([]core.CallID, len(calls)) 59 | 60 | // fill call lookup 61 | for i, c := range calls { 62 | h.lookupCall[c.GetID()] = c 63 | callIDs[i] = c.GetID() 64 | } 65 | 66 | // add to conn-call lookup 67 | h.lookupConnectionCall[connID] = append(h.lookupConnectionCall[connID], callIDs...) 
68 | } 69 | 70 | return nil 71 | } 72 | -------------------------------------------------------------------------------- /dbee/handler/event_bus.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/neovim/go-client/nvim" 7 | 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | "github.com/kndndrj/nvim-dbee/dbee/plugin" 10 | ) 11 | 12 | type eventBus struct { 13 | vim *nvim.Nvim 14 | log *plugin.Logger 15 | } 16 | 17 | func (eb *eventBus) callLua(event string, data string) { 18 | err := eb.vim.ExecLua(fmt.Sprintf(`require("dbee.handler.__events").trigger(%q, %s)`, event, data), nil) 19 | if err != nil { 20 | eb.log.Infof("eb.vim.ExecLua: %s", err) 21 | } 22 | } 23 | 24 | func (eb *eventBus) CallStateChanged(call *core.Call) { 25 | errMsg := "nil" 26 | if err := call.Err(); err != nil { 27 | errMsg = fmt.Sprintf("[[%s]]", err.Error()) 28 | } 29 | 30 | data := fmt.Sprintf(`{ 31 | call = { 32 | id = %q, 33 | query = %q, 34 | state = %q, 35 | time_taken_us = %d, 36 | timestamp_us = %d, 37 | error = %s, 38 | }, 39 | }`, call.GetID(), 40 | call.GetQuery(), 41 | call.GetState().String(), 42 | call.GetTimeTaken().Microseconds(), 43 | call.GetTimestamp().UnixMicro(), 44 | errMsg) 45 | 46 | eb.callLua("call_state_changed", data) 47 | } 48 | 49 | func (eb *eventBus) CurrentConnectionChanged(id core.ConnectionID) { 50 | data := fmt.Sprintf(`{ 51 | conn_id = %q, 52 | }`, id) 53 | 54 | eb.callLua("current_connection_changed", data) 55 | } 56 | 57 | // DatabaseSelected is called when the selected database of a connection is changed. 58 | // Sends the new database name along with affected connection ID to the lua event handler. 
59 | func (eb *eventBus) DatabaseSelected(id core.ConnectionID, dbname string) { 60 | data := fmt.Sprintf(`{ 61 | conn_id = %q, 62 | database_name = %q, 63 | }`, id, dbname) 64 | 65 | eb.callLua("database_selected", data) 66 | } 67 | -------------------------------------------------------------------------------- /dbee/handler/format_table.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "github.com/jedib0t/go-pretty/v6/table" 5 | "github.com/jedib0t/go-pretty/v6/text" 6 | 7 | "github.com/kndndrj/nvim-dbee/dbee/core" 8 | ) 9 | 10 | var _ core.Formatter = (*Table)(nil) 11 | 12 | type Table struct{} 13 | 14 | func newTable() *Table { 15 | return &Table{} 16 | } 17 | 18 | func (tf *Table) Format(header core.Header, rows []core.Row, opts *core.FormatterOptions) ([]byte, error) { 19 | tableHeaders := []any{""} 20 | for _, k := range header { 21 | tableHeaders = append(tableHeaders, k) 22 | } 23 | index := opts.ChunkStart 24 | 25 | var tableRows []table.Row 26 | for _, row := range rows { 27 | indexedRow := append([]any{index + 1}, row...) 
28 | tableRows = append(tableRows, table.Row(indexedRow)) 29 | index += 1 30 | } 31 | 32 | t := table.NewWriter() 33 | t.AppendHeader(table.Row(tableHeaders)) 34 | t.AppendRows(tableRows) 35 | t.AppendSeparator() 36 | t.SetStyle(table.StyleLight) 37 | t.Style().Format = table.FormatOptions{ 38 | Footer: text.FormatDefault, 39 | Header: text.FormatDefault, 40 | Row: text.FormatDefault, 41 | } 42 | t.Style().Options.DrawBorder = false 43 | t.SuppressTrailingSpaces() 44 | render := t.Render() 45 | 46 | return []byte(render), nil 47 | } 48 | -------------------------------------------------------------------------------- /dbee/handler/output_buffer.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | 7 | "github.com/neovim/go-client/nvim" 8 | ) 9 | 10 | func newBuffer(vim *nvim.Nvim, buffer nvim.Buffer) *Buffer { 11 | return &Buffer{ 12 | buffer: buffer, 13 | vim: vim, 14 | } 15 | } 16 | 17 | type Buffer struct { 18 | buffer nvim.Buffer 19 | vim *nvim.Nvim 20 | } 21 | 22 | func (b *Buffer) Write(p []byte) (int, error) { 23 | scanner := bufio.NewScanner(bytes.NewReader(p)) 24 | var lines [][]byte 25 | for scanner.Scan() { 26 | lines = append(lines, []byte(scanner.Text())) 27 | } 28 | 29 | const modifiableOptionName = "modifiable" 30 | 31 | // is the buffer modifiable 32 | isModifiable := false 33 | err := b.vim.BufferOption(b.buffer, modifiableOptionName, &isModifiable) 34 | if err != nil { 35 | return 0, err 36 | } 37 | 38 | if !isModifiable { 39 | err = b.vim.SetBufferOption(b.buffer, modifiableOptionName, true) 40 | if err != nil { 41 | return 0, err 42 | } 43 | } 44 | 45 | err = b.vim.SetBufferLines(b.buffer, 0, -1, true, lines) 46 | if err != nil { 47 | return 0, err 48 | } 49 | 50 | if !isModifiable { 51 | err = b.vim.SetBufferOption(b.buffer, modifiableOptionName, false) 52 | if err != nil { 53 | return 0, err 54 | } 55 | } 56 | 57 | return len(p), nil 58 | } 59 | 
-------------------------------------------------------------------------------- /dbee/handler/output_yank.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/neovim/go-client/nvim" 7 | ) 8 | 9 | type YankRegister struct { 10 | vim *nvim.Nvim 11 | register string 12 | } 13 | 14 | func newYankRegister(vim *nvim.Nvim, register string) *YankRegister { 15 | return &YankRegister{ 16 | vim: vim, 17 | register: register, 18 | } 19 | } 20 | 21 | func (yr *YankRegister) Write(p []byte) (int, error) { 22 | err := yr.vim.Call("setreg", nil, yr.register, string(p)) 23 | if err != nil { 24 | return 0, fmt.Errorf("r.vim.Call: %w", err) 25 | } 26 | 27 | return len(p), err 28 | } 29 | -------------------------------------------------------------------------------- /dbee/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "os" 8 | "runtime/debug" 9 | 10 | "github.com/neovim/go-client/nvim" 11 | 12 | "github.com/kndndrj/nvim-dbee/dbee/handler" 13 | "github.com/kndndrj/nvim-dbee/dbee/plugin" 14 | ) 15 | 16 | func main() { 17 | generateManifest := flag.String("manifest", "", "Generate manifest to file (filename of manifest).") 18 | getVersion := flag.Bool("version", false, "Get version and exit.") 19 | flag.Parse() 20 | 21 | // get version info 22 | if *getVersion { 23 | info, ok := debug.ReadBuildInfo() 24 | if !ok { 25 | fmt.Println("unknown") 26 | os.Exit(1) 27 | } 28 | for _, inf := range info.Settings { 29 | if inf.Key == "vcs.revision" { 30 | fmt.Println(inf.Value) 31 | return 32 | } 33 | } 34 | fmt.Println("unknown") 35 | os.Exit(1) 36 | } 37 | 38 | stdout := os.Stdout 39 | os.Stdout = os.Stderr 40 | log.SetFlags(0) 41 | 42 | v, err := nvim.New(os.Stdin, stdout, stdout, log.Printf) 43 | if err != nil { 44 | log.Fatal(err) 45 | } 46 | 47 | logger := plugin.NewLogger(v) 48 | 
49 | p := plugin.New(v, logger) 50 | 51 | h := handler.New(v, logger) 52 | defer h.Close() 53 | 54 | // configure "endpoints" from handler 55 | mountEndpoints(p, h) 56 | 57 | // generate manifest 58 | if *generateManifest != "" { 59 | err := p.Manifest("nvim_dbee", "dbee", *generateManifest) 60 | if err != nil { 61 | log.Fatal(err) 62 | } 63 | log.Println("generated manifest to " + *generateManifest) 64 | return 65 | } 66 | 67 | // start server 68 | if err := v.Serve(); err != nil { 69 | log.Fatal(err) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /dbee/plugin/logger.go: -------------------------------------------------------------------------------- 1 | package plugin 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "path/filepath" 8 | 9 | "github.com/neovim/go-client/nvim" 10 | ) 11 | 12 | type Logger struct { 13 | vim *nvim.Nvim 14 | logger *log.Logger 15 | file *os.File 16 | triedFileSet bool 17 | } 18 | 19 | func NewLogger(vim *nvim.Nvim) *Logger { 20 | return &Logger{ 21 | vim: vim, 22 | logger: log.New(os.Stdout, "", log.Ldate|log.Ltime), 23 | triedFileSet: false, 24 | } 25 | } 26 | 27 | func (l *Logger) setupFile() error { 28 | var fileName string 29 | err := l.vim.Call("stdpath", &fileName, "cache") 30 | if err != nil { 31 | return err 32 | } 33 | fileName = filepath.Join(fileName, "dbee", "dbee.log") 34 | 35 | file, err := os.OpenFile(fileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o666) 36 | if err != nil { 37 | return err 38 | } 39 | 40 | l.logger.SetOutput(file) 41 | return nil 42 | } 43 | 44 | func (l *Logger) Close() { 45 | if l.file != nil { 46 | l.file.Close() 47 | } 48 | } 49 | 50 | func (l *Logger) log(level, message string) { 51 | if l.file == nil && !l.triedFileSet { 52 | err := l.setupFile() 53 | if err != nil { 54 | l.logger.Print(err) 55 | } 56 | l.triedFileSet = true 57 | } 58 | 59 | l.logger.Printf("[%s]: %s", level, message) 60 | } 61 | 62 | func (l *Logger) Infof(format string, args 
...any) { 63 | l.log("info", fmt.Sprintf(format, args...)) 64 | } 65 | 66 | func (l *Logger) Errorf(format string, args ...any) { 67 | l.log("error", fmt.Sprintf(format, args...)) 68 | } 69 | -------------------------------------------------------------------------------- /dbee/plugin/manifest.go: -------------------------------------------------------------------------------- 1 | package plugin 2 | 3 | const manifestLuaFile = `-- This file is automatically generated using "dbee -manifest " 4 | -- DO NOT EDIT! 5 | 6 | return function() 7 | -- Register host 8 | vim.fn["remote#host#Register"]("{{ .Host }}", "x", function() 9 | return vim.fn.jobstart({ "{{ .Executable }}" }, { 10 | rpc = true, 11 | detach = true, 12 | on_stderr = function(_, data, _) 13 | for _, line in ipairs(data) do 14 | print(line) 15 | end 16 | end, 17 | }) 18 | end) 19 | 20 | -- Manifest 21 | vim.fn["remote#host#RegisterPlugin"]("{{ .Host }}", "0", { 22 | {{- range .Specs }} 23 | { type = "{{ .Type }}", name = "{{ .Name }}", sync = {{ .Sync }}, opts = vim.empty_dict() }, 24 | {{- end }} 25 | }) 26 | end 27 | ` 28 | -------------------------------------------------------------------------------- /dbee/plugin/plugin.go: -------------------------------------------------------------------------------- 1 | package plugin 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "reflect" 7 | "sort" 8 | "text/template" 9 | 10 | "github.com/neovim/go-client/nvim" 11 | ) 12 | 13 | // Plugin represents a remote plugin. 14 | type Plugin struct { 15 | vim *nvim.Nvim 16 | pluginSpecs []*pluginSpec 17 | log *Logger 18 | } 19 | 20 | // New returns an intialized plugin. 
21 | func New(v *nvim.Nvim, l *Logger) *Plugin { 22 | return &Plugin{ 23 | vim: v, 24 | log: l, 25 | } 26 | } 27 | // pluginSpec describes one registered handler; Type, Name and Sync are rendered into the manifest. 28 | type pluginSpec struct { 29 | sm string // handler method name passed to RegisterHandler (also the manifest sort key) 30 | Type string `msgpack:"type"` 31 | Name string `msgpack:"name"` 32 | Sync bool `msgpack:"sync"` 33 | Opts map[string]string `msgpack:"opts"` 34 | } 35 | // isSync reports whether f is a function that returns at least one value. 36 | func isSync(f interface{}) bool { 37 | t := reflect.TypeOf(f) 38 | 39 | return t.Kind() == reflect.Func && t.NumOut() > 0 40 | } 41 | // handle records spec for the manifest and, when connected to neovim, registers fn under spec.sm (panicking if that fails). 42 | func (p *Plugin) handle(fn interface{}, spec *pluginSpec) { 43 | p.pluginSpecs = append(p.pluginSpecs, spec) 44 | if p.vim == nil { 45 | return 46 | } 47 | 48 | if err := p.vim.RegisterHandler(spec.sm, fn); err != nil { 49 | panic(err) 50 | } 51 | } 52 | // logReturn logs a handler's outcome: an error entry for the first non-nil error among values, otherwise a success entry. 53 | func (p *Plugin) logReturn(method string, values []reflect.Value) { 54 | // report the first non-nil error return value, if any 55 | for _, val := range values { 56 | ret := val.Interface() 57 | 58 | if err, ok := ret.(error); ok && err != nil { 59 | p.log.Errorf("method %q failed with error: %s", method, err) 60 | return 61 | } 62 | } 63 | 64 | p.log.Infof("method %q returned successfully", method) 65 | } 66 | 67 | // RegisterEndpoint registers fn as a handler for a vim function. The function 68 | // signature for fn is one of 69 | // 70 | // func([v *nvim.Nvim,] args {arrayType}) ({resultType}, error) 71 | // func([v *nvim.Nvim,] args {arrayType}) error 72 | // 73 | // where {arrayType} is a type that can be unmarshaled from a MessagePack 74 | // array and {resultType} is the type of function result.
75 | func (p *Plugin) RegisterEndpoint(name string, fn any) { 76 | v := reflect.ValueOf(fn) 77 | 78 | newFn := reflect.MakeFunc(v.Type(), func(args []reflect.Value) (results []reflect.Value) { 79 | p.log.Infof("calling method %q", name) 80 | ret := v.Call(args) 81 | p.logReturn(name, ret) 82 | return ret 83 | }) 84 | 85 | p.handle(newFn.Interface(), &pluginSpec{ 86 | sm: `0:function:` + name, 87 | Type: `function`, 88 | Name: name, 89 | Sync: isSync(fn), 90 | Opts: make(map[string]string), 91 | }) 92 | } 93 | 94 | func (p *Plugin) Manifest(host, executable, writeTo string) error { 95 | // Sort for consistent order on output. 96 | sort.Slice(p.pluginSpecs, func(i, j int) bool { 97 | return p.pluginSpecs[i].sm < p.pluginSpecs[j].sm 98 | }) 99 | 100 | tmpl, err := template.New("manifest_template").Parse(manifestLuaFile) 101 | if err != nil { 102 | return fmt.Errorf("template.New.Parse: %w", err) 103 | } 104 | 105 | outputFile, err := os.Create(writeTo) 106 | if err != nil { 107 | return fmt.Errorf("os.Create: %w", err) 108 | } 109 | 110 | err = tmpl.Execute(outputFile, struct { 111 | Host string 112 | Executable string 113 | Specs []*pluginSpec 114 | }{ 115 | Host: host, 116 | Executable: executable, 117 | Specs: p.pluginSpecs, 118 | }) 119 | if err != nil { 120 | return fmt.Errorf("tmpl.Execute: %w", err) 121 | } 122 | 123 | return nil 124 | } 125 | -------------------------------------------------------------------------------- /dbee/tests/README.md: -------------------------------------------------------------------------------- 1 | # Raison d'être 2 | 3 | This directory contains tests for the dbee project that are not unit tests. 4 | 5 | ## Tests: 6 | 7 | Try to follow the uber-go style guide for tests, which can be found 8 | [here](https://github.com/uber-go/guide/blob/master/style.md#test-tables). 
9 | 10 | ### Container providers 11 | 12 | [Go testcontainers](https://golang.testcontainers.org/modules) is used to run integration tests 13 | against the [adapters](./../adapters) package. 14 | 15 | Testcontainers supports two providers: docker and podman. If the `podman` executable is detected, 16 | it is used as the default provider; otherwise `docker` is used. When using podman, 17 | the ryuk (reaper) container needs to run as privileged; set the environment variable 18 | `TESTCONTAINERS_RYUK_CONTAINER_PRIVILEGED=true` before running any tests to enable this. 19 | 20 | ### How to run tests 21 | 22 | Run the tests with (from the `dbee` directory): 23 | 24 | ```bash 25 | go test ./tests/... -v 26 | ``` 27 | 28 | To disable the test cache, add the `-count=1` flag: 29 | 30 | ```bash 31 | go test ./tests/... -v -count=1 32 | ``` 33 | 34 | To run the tests of a specific adapter, use the `-run` flag: 35 | 36 | ```bash 37 | go test ./tests/... -v -run Test<Adapter> 38 | ``` 39 | 40 | For example, to run the `postgres` adapter tests: 41 | 42 | ```bash 43 | go test ./tests/... -v -run TestPostgres 44 | ``` 45 | 46 | ### Add new tests 47 | 48 | Take a look at the `postgres` adapter tests for an example of how to add a new integration. The 49 | [testcontainers module documentation](https://golang.testcontainers.org/modules) is also 50 | a helpful reference.
51 | -------------------------------------------------------------------------------- /dbee/tests/integration/bigquery_integration_test.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "testing" 7 | "time" 8 | 9 | "github.com/kndndrj/nvim-dbee/dbee/core" 10 | th "github.com/kndndrj/nvim-dbee/dbee/tests/testhelpers" 11 | "github.com/stretchr/testify/assert" 12 | tsuite "github.com/stretchr/testify/suite" 13 | tc "github.com/testcontainers/testcontainers-go" 14 | ) 15 | 16 | // BigQueryTestSuite is the test suite for the bigquery adapter. 17 | type BigQueryTestSuite struct { 18 | tsuite.Suite 19 | ctr *th.BigQueryContainer 20 | ctx context.Context 21 | d *core.Connection 22 | } 23 | 24 | // TestBigQueryTestSuite is the entrypoint for go test. 25 | func TestBigQueryTestSuite(t *testing.T) { 26 | tsuite.Run(t, new(BigQueryTestSuite)) 27 | } 28 | // SetupSuite starts the BigQuery test container and keeps its driver for the tests. 29 | func (suite *BigQueryTestSuite) SetupSuite() { 30 | suite.ctx = context.Background() 31 | ctr, err := th.NewBigQueryContainer(suite.ctx, &core.ConnectionParams{ 32 | ID: "test-bigquery", 33 | Name: "test-bigquery", 34 | }) 35 | if err != nil { 36 | log.Fatal(err) 37 | } 38 | 39 | suite.ctr = ctr 40 | suite.d = ctr.Driver 41 | } 42 | // TearDownSuite implements testify's TearDownAllSuite; with any other spelling it is never called. 43 | func (suite *BigQueryTestSuite) TearDownSuite() { 44 | tc.CleanupContainer(suite.T(), suite.ctr) 45 | } 46 | 47 | func (suite *BigQueryTestSuite) TestShouldErrorInvalidQuery() { 48 | t := suite.T() 49 | 50 | want := "Syntax error" 51 | 52 | call := suite.d.Execute("invalid sql", func(cs core.CallState, c *core.Call) { 53 | if cs == core.CallStateExecutingFailed { 54 | assert.ErrorContains(t, c.Err(), want) 55 | } 56 | }) 57 | assert.NotNil(t, call) 58 | } 59 | 60 | func (suite *BigQueryTestSuite) TestShouldCancelQuery() { 61 | t := suite.T() 62 | want := 
[]core.CallState{core.CallStateExecuting, core.CallStateCanceled} 63 | 64 | _, got, err := th.GetResultWithCancel(t, suite.d, "SELECT
1") 65 | assert.NoError(t, err) 66 | 67 | assert.Equal(t, want, got) 68 | } 69 | 70 | func (suite *BigQueryTestSuite) TestShouldReturnOneRow() { 71 | t := suite.T() 72 | 73 | wantStates := []core.CallState{ 74 | core.CallStateExecuting, core.CallStateRetrieving, core.CallStateArchived, 75 | } 76 | wantCols := []string{"id", "createdAt", "name"} 77 | wantRows := []core.Row{ 78 | { 79 | int64(1), 80 | "john", 81 | time.Date(2025, 1, 21, 0, 0, 0, 0, time.UTC), 82 | }, 83 | } 84 | 85 | query := "SELECT id, name, createdAt FROM `dataset_test.table_test` WHERE id = 1" 86 | 87 | gotRows, gotCols, gotStates, err := th.GetResult(t, suite.d, query) 88 | assert.NoError(t, err) 89 | 90 | assert.ElementsMatch(t, wantCols, gotCols) 91 | assert.ElementsMatch(t, wantStates, gotStates) 92 | assert.Equal(t, wantRows, gotRows) 93 | } 94 | 95 | func (suite *BigQueryTestSuite) TestShouldReturnManyRows() { 96 | t := suite.T() 97 | 98 | wantRows := []core.Row{ 99 | { 100 | int64(1), 101 | "john", 102 | time.Date(2025, 1, 21, 0, 0, 0, 0, time.UTC), 103 | }, 104 | { 105 | int64(2), 106 | "bob", 107 | time.Date(2025, 1, 21, 0, 1, 0, 0, time.UTC), 108 | }, 109 | } 110 | query := "SELECT id, name, createdAt FROM `dataset_test.table_test` WHERE id IN (1, 2)" 111 | 112 | gotRows, _, _, err := th.GetResult(t, suite.d, query) 113 | assert.NoError(t, err) 114 | assert.Equal(t, wantRows, gotRows) 115 | } 116 | 117 | func (suite *BigQueryTestSuite) TestShouldReturnStructure() { 118 | t := suite.T() 119 | 120 | wantSomeSchema, wantSomeTable := "dataset_test", "table_test" 121 | 122 | structure, err := suite.d.GetStructure() 123 | assert.NoError(t, err) 124 | 125 | gotSchemas := th.GetSchemas(t, structure) 126 | assert.Contains(t, gotSchemas, wantSomeSchema) 127 | 128 | gotTables := th.GetModels(t, structure, core.StructureTypeTable) 129 | assert.Contains(t, gotTables, wantSomeTable) 130 | } 131 | 132 | func (suite *BigQueryTestSuite) TestShouldReturnColumns() { 133 | t := suite.T() 134 | 135 | want
:= []*core.Column{ 136 | {Name: "id", Type: "INTEGER"}, 137 | {Name: "name", Type: "STRING"}, 138 | {Name: "createdAt", Type: "TIMESTAMP"}, 139 | } 140 | 141 | got, err := suite.d.GetColumns(&core.TableOptions{ 142 | Table: "table_test", 143 | Schema: "dataset_test", 144 | Materialization: core.StructureTypeTable, 145 | }) 146 | 147 | assert.NoError(t, err) 148 | assert.Equal(t, want, got) 149 | } 150 | -------------------------------------------------------------------------------- /dbee/tests/integration/docs.go: -------------------------------------------------------------------------------- 1 | // Package integration provides integration tests for dbee adapters. 2 | package integration 3 | -------------------------------------------------------------------------------- /dbee/tests/integration/mysql_integration_test.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "testing" 7 | 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | th "github.com/kndndrj/nvim-dbee/dbee/tests/testhelpers" 10 | "github.com/stretchr/testify/assert" 11 | tsuite "github.com/stretchr/testify/suite" 12 | tc "github.com/testcontainers/testcontainers-go" 13 | ) 14 | 15 | // MySQLTestSuite is the test suite for the mysql adapter.
16 | type MySQLTestSuite struct { 17 | tsuite.Suite 18 | ctr *th.MySQLContainer 19 | ctx context.Context 20 | d *core.Connection 21 | } 22 | // TestMySQLTestSuite is the entrypoint for go test. 23 | func TestMySQLTestSuite(t *testing.T) { 24 | tsuite.Run(t, new(MySQLTestSuite)) 25 | } 26 | 27 | func (suite *MySQLTestSuite) SetupSuite() { 28 | suite.ctx = context.Background() 29 | ctr, err := th.NewMySQLContainer(suite.ctx, &core.ConnectionParams{ 30 | ID: "test-mysql", 31 | Name: "test-mysql", 32 | }) 33 | if err != nil { 34 | log.Fatal(err) 35 | } 36 | 37 | suite.ctr = ctr 38 | suite.d = ctr.Driver // easier access to driver 39 | } 40 | // TearDownSuite implements testify's TearDownAllSuite; with any other spelling it is never called. 41 | func (suite *MySQLTestSuite) TearDownSuite() { 42 | tc.CleanupContainer(suite.T(), suite.ctr) 43 | } 44 | 45 | func (suite *MySQLTestSuite) TestShouldErrorInvalidQuery() { 46 | t := suite.T() 47 | 48 | want := "You have an error in your SQL syntax" 49 | 50 | call := suite.d.Execute("invalid sql", func(cs core.CallState, c *core.Call) { 51 | if cs == core.CallStateExecutingFailed { 52 | assert.ErrorContains(t, c.Err(), want) 53 | } 54 | }) 55 | assert.NotNil(t, call) 56 | } 57 | 58 | func (suite *MySQLTestSuite) TestShouldCancelQuery() { 59 | t := suite.T() 60 | want := []core.CallState{core.CallStateExecuting, core.CallStateCanceled} 61 | 62 | _, got, err := th.GetResultWithCancel(t, suite.d, "SELECT 1") 63 | assert.NoError(t, err) 64 | 65 | assert.Equal(t, want, got) 66 | } 67 | 68 | func (suite *MySQLTestSuite) TestShouldReturnManyRows() { 69 | t := suite.T() 70 | 71 | wantStates := []core.CallState{ 72 | core.CallStateExecuting, core.CallStateRetrieving, core.CallStateArchived, 73 | } 74 | wantCols := []string{"id", "username", "email"} 75 | wantRows := []core.Row{ 76 | {"1", "john_doe", "john@example.com"}, 77 | {"2", "jane_smith", "jane@example.com"}, 78 | {"3", "bob_wilson", "bob@example.com"}, 79 | } 80 | 81 | query := "SELECT * FROM test.test_table" 82 | 83 | 
gotRows, gotCols, gotStates, err := th.GetResult(t, suite.d, query) 84 | assert.NoError(t, err) 85 | 86 |
assert.ElementsMatch(t, wantCols, gotCols) 87 | assert.ElementsMatch(t, wantStates, gotStates) 88 | assert.Equal(t, wantRows, gotRows) 89 | } 90 | 91 | func (suite *MySQLTestSuite) TestShouldReturnSingleRows() { 92 | t := suite.T() 93 | 94 | wantStates := []core.CallState{ 95 | core.CallStateExecuting, core.CallStateRetrieving, core.CallStateArchived, 96 | } 97 | wantCols := []string{"id", "username", "email"} 98 | wantRows := []core.Row{ 99 | {"2", "jane_smith", "jane@example.com"}, 100 | } 101 | 102 | query := "SELECT * FROM test.test_view" 103 | 104 | gotRows, gotCols, gotStates, err := th.GetResult(t, suite.d, query) 105 | assert.NoError(t, err) 106 | 107 | assert.ElementsMatch(t, wantCols, gotCols) 108 | assert.ElementsMatch(t, wantStates, gotStates) 109 | assert.Equal(t, wantRows, gotRows) 110 | } 111 | 112 | func (suite *MySQLTestSuite) TestShouldReturnStructure() { 113 | t := suite.T() 114 | 115 | wantSchemas := []string{"information_schema", "mysql", "performance_schema", "sys", "test"} 116 | wantSomeTable := "test_table" 117 | 118 | structure, err := suite.d.GetStructure() 119 | assert.NoError(t, err) 120 | 121 | gotSchemas := th.GetSchemas(t, structure) 122 | assert.ElementsMatch(t, wantSchemas, gotSchemas) 123 | 124 | gotTables := th.GetModels(t, structure, core.StructureTypeTable) 125 | assert.Contains(t, gotTables, wantSomeTable) 126 | } 127 | 128 | func (suite *MySQLTestSuite) TestShouldReturnColumns() { 129 | t := suite.T() 130 | 131 | want := []*core.Column{ 132 | {Name: "id", Type: "int unsigned"}, 133 | {Name: "username", Type: "varchar(255)"}, 134 | {Name: "email", Type: "varchar(255)"}, 135 | } 136 | 137 | got, err := suite.d.GetColumns(&core.TableOptions{ 138 | Table: "test_table", 139 | Schema: "test", 140 | Materialization: core.StructureTypeTable, 141 | }) 142 | 143 | assert.NoError(t, err) 144 | assert.Equal(t, want, got) 145 | } 146 | --------------------------------------------------------------------------------
/dbee/tests/integration/oracle_integration_test.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "testing" 7 | 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | th "github.com/kndndrj/nvim-dbee/dbee/tests/testhelpers" 10 | "github.com/stretchr/testify/assert" 11 | tsuite "github.com/stretchr/testify/suite" 12 | tc "github.com/testcontainers/testcontainers-go" 13 | ) 14 | 15 | // OracleTestSuite is the test suite for the oracle adapter. 16 | type OracleTestSuite struct { 17 | tsuite.Suite 18 | ctr *th.OracleContainer 19 | ctx context.Context 20 | d *core.Connection 21 | } 22 | // TestOracleTestSuite is the entrypoint for go test. 23 | func TestOracleTestSuite(t *testing.T) { 24 | tsuite.Run(t, new(OracleTestSuite)) 25 | } 26 | 27 | func (suite *OracleTestSuite) SetupSuite() { 28 | suite.ctx = context.Background() 29 | ctr, err := th.NewOracleContainer(suite.ctx, &core.ConnectionParams{ 30 | ID: "test-oracle", 31 | Name: "test-oracle", 32 | }) 33 | if err != nil { 34 | log.Fatal(err) 35 | } 36 | 37 | suite.ctr = ctr 38 | suite.d = ctr.Driver 39 | } 40 | // TearDownSuite implements testify's TearDownAllSuite; with any other spelling it is never called. 41 | func (suite *OracleTestSuite) TearDownSuite() { 42 | tc.CleanupContainer(suite.T(), suite.ctr) 43 | } 44 | 45 | func (suite *OracleTestSuite) TestShouldErrorInvalidQuery() { 46 | t := suite.T() 47 | 48 | want := "ORA-00900: invalid SQL statement" 49 | 50 | call := suite.d.Execute("invalid sql", func(cs core.CallState, c *core.Call) { 51 | if cs == core.CallStateExecutingFailed { 52 | assert.ErrorContains(t, c.Err(), want) 53 | } 54 | }) 55 | assert.NotNil(t, call) 56 | } 57 | 58 | func (suite *OracleTestSuite) TestShouldCancelQuery() { 59 | t := suite.T() 60 | want := []core.CallState{core.CallStateExecuting, core.CallStateCanceled} 61 | 62 | _, got, err := th.GetResultWithCancel(t, suite.d, "SELECT 1") 63 | assert.NoError(t, err) 64 | 65 | 
assert.Equal(t, want, got) 66 | } 67 | 68 | func (suite *OracleTestSuite) TestShouldReturnManyRows() { 69 | t := suite.T() 70
| 71 | wantStates := []core.CallState{ 72 | core.CallStateExecuting, core.CallStateRetrieving, core.CallStateArchived, 73 | } 74 | wantCols := []string{"ID", "USERNAME"} 75 | wantRows := []core.Row{ 76 | {"1", "john_doe"}, 77 | {"2", "jane_smith"}, 78 | {"3", "bob_wilson"}, 79 | } 80 | 81 | query := "SELECT ID, USERNAME FROM test_table" 82 | 83 | gotRows, gotCols, gotStates, err := th.GetResult(t, suite.d, query) 84 | assert.NoError(t, err) 85 | 86 | assert.ElementsMatch(t, wantCols, gotCols) 87 | assert.ElementsMatch(t, wantStates, gotStates) 88 | assert.Equal(t, wantRows, gotRows) 89 | } 90 | 91 | func (suite *OracleTestSuite) TestShouldReturnOneRow() { 92 | t := suite.T() 93 | 94 | wantStates := []core.CallState{ 95 | core.CallStateExecuting, core.CallStateRetrieving, core.CallStateArchived, 96 | } 97 | wantCols := []string{"ID", "USERNAME"} 98 | wantRows := []core.Row{{"2", "jane_smith"}} 99 | 100 | query := "SELECT ID, USERNAME FROM test_view" 101 | 102 | gotRows, gotCols, gotStates, err := th.GetResult(t, suite.d, query) 103 | assert.NoError(t, err) 104 | 105 | assert.ElementsMatch(t, wantCols, gotCols) 106 | assert.ElementsMatch(t, wantStates, gotStates) 107 | assert.Equal(t, wantRows, gotRows) 108 | } 109 | 110 | func (suite *OracleTestSuite) TestShouldReturnStructure() { 111 | t := suite.T() 112 | 113 | var ( 114 | wantSomeSchema = "TESTER" 115 | wantSomeTable = "TEST_TABLE" 116 | wantSomeView = "TEST_VIEW" 117 | ) 118 | 119 | structure, err := suite.d.GetStructure() 120 | assert.NoError(t, err) 121 | 122 | gotSchemas := th.GetSchemas(t, structure) 123 | assert.Contains(t, gotSchemas, wantSomeSchema) 124 | 125 | gotTables := th.GetModels(t, structure, core.StructureTypeTable) 126 | assert.Contains(t, gotTables, wantSomeTable) 127 | 128 | gotViews := th.GetModels(t, structure, core.StructureTypeView) 129 | assert.Contains(t, gotViews, wantSomeView) 130 | } 131 | 132 | func (suite *OracleTestSuite) TestShouldReturnColumns() { 133 | t := suite.T() 134 | 135 |
want := []*core.Column{ 136 | {Name: "ID", Type: "NUMBER"}, 137 | {Name: "USERNAME", Type: "VARCHAR2"}, 138 | {Name: "EMAIL", Type: "VARCHAR2"}, 139 | } 140 | 141 | got, err := suite.d.GetColumns(&core.TableOptions{ 142 | Table: "TEST_TABLE", 143 | Schema: "TESTER", 144 | Materialization: core.StructureTypeTable, 145 | }) 146 | 147 | assert.NoError(t, err) 148 | assert.Equal(t, want, got) 149 | } 150 | -------------------------------------------------------------------------------- /dbee/tests/integration/sqlite_integration_test.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "testing" 7 | 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | th "github.com/kndndrj/nvim-dbee/dbee/tests/testhelpers" 10 | "github.com/stretchr/testify/assert" 11 | tsuite "github.com/stretchr/testify/suite" 12 | tc "github.com/testcontainers/testcontainers-go" 13 | ) 14 | 15 | // SQLiteTestSuite is the test suite for the sqlite adapter.
16 | type SQLiteTestSuite struct { 17 | tsuite.Suite 18 | ctr *th.SQLiteContainer 19 | ctx context.Context 20 | d *core.Connection 21 | } 22 | // TestSQLiteTestSuite is the entrypoint for go test. 23 | func TestSQLiteTestSuite(t *testing.T) { 24 | tsuite.Run(t, new(SQLiteTestSuite)) 25 | } 26 | 27 | func (suite *SQLiteTestSuite) SetupSuite() { 28 | suite.ctx = context.Background() 29 | tempDir := suite.T().TempDir() 30 | 31 | params := &core.ConnectionParams{ID: "test-sqlite", Name: "test-sqlite"} 32 | ctr, err := th.NewSQLiteContainer(suite.ctx, params, tempDir) 33 | if err != nil { 34 | log.Fatal(err) 35 | } 36 | 37 | suite.ctr, suite.d = ctr, ctr.Driver 38 | } 39 | // TearDownSuite implements testify's TearDownAllSuite; with any other spelling it is never called. 40 | func (suite *SQLiteTestSuite) TearDownSuite() { 41 | tc.CleanupContainer(suite.T(), suite.ctr) 42 | } 43 | 44 | func (suite *SQLiteTestSuite) TestShouldErrorInvalidQuery() { 45 | t := suite.T() 46 | 47 | want := "syntax error" 48 | 49 | call := suite.d.Execute("invalid sql", func(cs core.CallState, c *core.Call) { 50 | if cs == core.CallStateExecutingFailed { 51 | assert.ErrorContains(t, c.Err(), want) 52 | } 53 | }) 54 | assert.NotNil(t, call) 55 | } 56 | 57 | func (suite *SQLiteTestSuite) TestShouldCancelQuery() { 58 | t := suite.T() 59 | want := []core.CallState{core.CallStateExecuting, core.CallStateCanceled} 60 | 61 | _, got, err := th.GetResultWithCancel(t, suite.d, "SELECT 1") 62 | assert.NoError(t, err) 63 | 64 | assert.Equal(t, want, got) 65 | } 66 | 67 | func (suite *SQLiteTestSuite) TestShouldReturnManyRows() { 68 | t := suite.T() 69 | 70 | wantStates := []core.CallState{ 71 | core.CallStateExecuting, core.CallStateRetrieving, core.CallStateArchived, 72 | } 73 | wantCols := []string{"id", "username"} 74 | wantRows := []core.Row{ 75 | {int64(1), "john_doe"}, 76 | {int64(2), "jane_smith"}, 77 | {int64(3), "bob_wilson"}, 78 | } 79 | 80 | query := "SELECT id, username FROM test_table" 81 | 82 | gotRows, gotCols, gotStates, err := th.GetResult(t, suite.d, 
query) 83 | assert.NoError(t, err) 84 | 85 | assert.ElementsMatch(t, wantCols, gotCols)
86 | assert.ElementsMatch(t, wantStates, gotStates) 87 | assert.Equal(t, wantRows, gotRows) 88 | } 89 | 90 | func (suite *SQLiteTestSuite) TestShouldReturnOneRow() { 91 | t := suite.T() 92 | 93 | wantStates := []core.CallState{ 94 | core.CallStateExecuting, core.CallStateRetrieving, core.CallStateArchived, 95 | } 96 | wantCols := []string{"id", "username"} 97 | wantRows := []core.Row{{int64(2), "jane_smith"}} 98 | 99 | query := "SELECT id, username FROM test_view" 100 | 101 | gotRows, gotCols, gotStates, err := th.GetResult(t, suite.d, query) 102 | assert.NoError(t, err) 103 | 104 | assert.ElementsMatch(t, wantCols, gotCols) 105 | assert.ElementsMatch(t, wantStates, gotStates) 106 | assert.Equal(t, wantRows, gotRows) 107 | } 108 | 109 | func (suite *SQLiteTestSuite) TestShouldReturnStructure() { 110 | t := suite.T() 111 | 112 | var ( 113 | wantSchema = "sqlite_schema" 114 | wantSomeTable = "test_table" 115 | wantSomeView = "test_view" 116 | ) 117 | 118 | structure, err := suite.d.GetStructure() 119 | assert.NoError(t, err) 120 | 121 | gotSchemas := th.GetSchemas(t, structure) 122 | assert.Contains(t, gotSchemas, wantSchema) 123 | 124 | gotTables := th.GetModels(t, structure, core.StructureTypeTable) 125 | assert.Contains(t, gotTables, wantSomeTable) 126 | 127 | gotViews := th.GetModels(t, structure, core.StructureTypeView) 128 | assert.Contains(t, gotViews, wantSomeView) 129 | } 130 | 131 | func (suite *SQLiteTestSuite) TestShouldReturnColumns() { 132 | t := suite.T() 133 | 134 | want := []*core.Column{ 135 | {Name: "id", Type: "INTEGER"}, 136 | {Name: "username", Type: "TEXT"}, 137 | {Name: "email", Type: "TEXT"}, 138 | } 139 | 140 | got, err := suite.d.GetColumns(&core.TableOptions{ 141 | Table: "test_table", 142 | Schema: "sqlite_schema", 143 | Materialization: core.StructureTypeTable, 144 | }) 145 | 146 | assert.NoError(t, err) 147 | assert.Equal(t, want, got) 148 | } 149 | // TestShouldNoOperationSwitchDatabase checks that SelectDatabase on a fresh sqlite driver returns no error. 150 | func (suite *SQLiteTestSuite) 
TestShouldNoOperationSwitchDatabase() { 151 | t :=
suite.T() 152 | 153 | driver, err := suite.ctr.NewDriver(&core.ConnectionParams{ 154 | ID: "test-sqlite-2", 155 | Name: "test-sqlite-2", 156 | }) 157 | // bail out early: the driver is not usable if NewDriver failed 158 | if !assert.NoError(t, err) { 159 | return 160 | } 161 | 162 | err = driver.SelectDatabase("no-op") 163 | assert.NoError(t, err) 164 | } 165 | -------------------------------------------------------------------------------- /dbee/tests/testdata/bigquery_seed.yaml: -------------------------------------------------------------------------------- 1 | # https://golang.testcontainers.org/modules/gcloud/#data-yaml-seed-file 2 | projects: 3 | - id: test-project 4 | datasets: 5 | - id: dataset_test 6 | tables: 7 | - id: table_test 8 | columns: 9 | - name: id 10 | type: INTEGER 11 | - name: name 12 | type: STRING 13 | - name: createdAt 14 | type: TIMESTAMP 15 | data: 16 | - id: 1 17 | name: john 18 | createdAt: "2025-01-21T00:00:00" 19 | - id: 2 20 | name: bob 21 | createdAt: "2025-01-21T00:01:00" 22 | - id: dataset_test.INFORMATION_SCHEMA 23 | tables: 24 | - id: COLUMNS 25 | columns: 26 | - name: TABLE_SCHEMA 27 | type: STRING 28 | - name: TABLE_NAME 29 | type: STRING 30 | - name: COLUMN_NAME 31 | type: STRING 32 | - name: DATA_TYPE 33 | type: STRING 34 | data: 35 | - TABLE_SCHEMA: dataset_test 36 | TABLE_NAME: table_test 37 | COLUMN_NAME: id 38 | DATA_TYPE: INTEGER 39 | - TABLE_SCHEMA: dataset_test 40 | TABLE_NAME: table_test 41 | COLUMN_NAME: name 42 | DATA_TYPE: STRING 43 | - TABLE_SCHEMA: dataset_test 44 | TABLE_NAME: table_test 45 | COLUMN_NAME: createdAt 46 | DATA_TYPE: TIMESTAMP 47 | 48 | - id: test-project2 49 | datasets: [] 50 | -------------------------------------------------------------------------------- /dbee/tests/testdata/clickhouse_seed.sql: -------------------------------------------------------------------------------- 1 | CREATE DATABASE IF NOT EXISTS test; 2 | 3 | CREATE TABLE IF NOT EXISTS test.test_table 4 | ( 5 | id UInt32, 6 | username String, 7 | email String, 8 | created_at DateTime, 9 | 
is_active UInt8 10 | ) ENGINE =
MergeTree() 11 | ORDER BY id 12 | ; 13 | 14 | INSERT INTO test.test_table (id, username, email, created_at, is_active) VALUES 15 | (1, 'john_doe', 'john@example.com', '2023-01-01 10:00:00', 1), 16 | (2, 'jane_smith', 'jane@example.com', '2023-01-02 11:30:00', 1), 17 | (3, 'bob_wilson', 'bob@example.com', '2023-01-03 09:15:00', 0) 18 | ; 19 | 20 | CREATE VIEW IF NOT EXISTS test.test_view AS 21 | SELECT id, username, email, created_at 22 | FROM test.test_table 23 | WHERE is_active = 1 24 | ; 25 | 26 | -------------------------------------------------------------------------------- /dbee/tests/testdata/duckdb_seed.sql: -------------------------------------------------------------------------------- 1 | CREATE SCHEMA IF NOT EXISTS test_container.test_schema; 2 | 3 | CREATE TABLE IF NOT EXISTS test_container.test_schema.test_table ( 4 | id INTEGER PRIMARY KEY, 5 | username TEXT NOT NULL, 6 | email TEXT NOT NULL, 7 | created_at TIMESTAMP NOT NULL 8 | ); 9 | 10 | INSERT INTO test_container.test_schema.test_table (id, username, email, created_at) VALUES 11 | (1, 'john_doe', 'john@example.com', '2023-01-01 10:00:00'), 12 | (2, 'jane_smith', 'jane@example.com', '2023-01-02 11:30:00'), 13 | (3, 'bob_wilson', 'bob@example.com', '2023-01-03 09:15:00'); 14 | 15 | CREATE OR REPLACE VIEW test_container.test_schema.test_view AS 16 | SELECT id, username, email 17 | FROM test_container.test_schema.test_table 18 | WHERE id = 2; 19 | -------------------------------------------------------------------------------- /dbee/tests/testdata/mysql_seed.sql: -------------------------------------------------------------------------------- 1 | CREATE SCHEMA IF NOT EXISTS test; 2 | 3 | CREATE TABLE IF NOT EXISTS test.test_table ( 4 | id INT UNSIGNED, 5 | username VARCHAR(255), 6 | email VARCHAR(255), 7 | PRIMARY KEY (id) 8 | ); 9 | 10 | INSERT INTO test.test_table (id, username, email) VALUES 11 | (1, 'john_doe', 'john@example.com'), 12 | (2, 'jane_smith', 'jane@example.com'), 13 | (3,
'bob_wilson', 'bob@example.com'); 14 | 15 | CREATE OR REPLACE VIEW test.test_view AS 16 | SELECT id, username, email 17 | FROM test.test_table 18 | WHERE id = 2; 19 | 20 | -------------------------------------------------------------------------------- /dbee/tests/testdata/oracle_seed.sql: -------------------------------------------------------------------------------- 1 | -- Connect as system to grant privileges 2 | ALTER SESSION SET CONTAINER = FREEPDB1; 3 | grant create session, create table, create view, unlimited tablespace 4 | to tester 5 | ; 6 | 7 | -- Must match the APP_USER env in testcontainer 8 | ALTER SESSION SET CURRENT_SCHEMA = tester; 9 | 10 | CREATE TABLE test_table ( 11 | id NUMBER, 12 | username VARCHAR2(255), 13 | email VARCHAR2(255), 14 | CONSTRAINT test_table_pk PRIMARY KEY (id) 15 | ); 16 | 17 | INSERT INTO test_table (id, username, email) VALUES 18 | (1, 'john_doe', 'john@example.com'); 19 | INSERT INTO test_table (id, username, email) VALUES 20 | (2, 'jane_smith', 'jane@example.com'); 21 | INSERT INTO test_table (id, username, email) VALUES 22 | (3, 'bob_wilson', 'bob@example.com'); 23 | 24 | CREATE OR REPLACE VIEW test_view AS 25 | SELECT id, username, email 26 | FROM test_table 27 | WHERE id = 2; 28 | 29 | commit 30 | ; 31 | 32 | -------------------------------------------------------------------------------- /dbee/tests/testdata/postgres_seed.sql: -------------------------------------------------------------------------------- 1 | 2 | CREATE SCHEMA IF NOT EXISTS test; 3 | 4 | CREATE TABLE IF NOT EXISTS test.test_table ( 5 | id INT, 6 | username VARCHAR(255), 7 | email VARCHAR(255), 8 | PRIMARY KEY (id) 9 | ); 10 | 11 | INSERT INTO test.test_table (id, username, email) VALUES 12 | (1, 'john_doe', 'john@example.com'), 13 | (2, 'jane_smith', 'jane@example.com'), 14 | (3, 'bob_wilson', 'bob@example.com'); 15 | 16 | CREATE OR REPLACE VIEW test.test_view AS ( 17 | SELECT id, username, email 18 | FROM test.test_table 19 | WHERE id = 2 20 | ); 21
| 22 | -------------------------------------------------------------------------------- /dbee/tests/testdata/sqlite_seed.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS test_table ( 2 | id INTEGER PRIMARY KEY, 3 | username TEXT, 4 | email TEXT 5 | ); 6 | 7 | INSERT INTO test_table (id, username, email) VALUES 8 | (1, 'john_doe', 'john@example.com'), 9 | (2, 'jane_smith', 'jane@example.com'), 10 | (3, 'bob_wilson', 'bob@example.com'); 11 | 12 | CREATE VIEW IF NOT EXISTS test_view AS 13 | SELECT id, username, email 14 | FROM test_table 15 | WHERE id = 2; 16 | 17 | -------------------------------------------------------------------------------- /dbee/tests/testdata/sqlserver_seed.sql: -------------------------------------------------------------------------------- 1 | /* 2 | Each batch needs to be separated by GO (GO is a batch separator, not a T-SQL statement), 3 | see the SQL Server container documentation: 4 | http://learn.microsoft.com/en-us/sql/linux/sql-server-linux-docker-container-deployment?view=sql-server-2017&pivots=cs1-bash 5 | */ 6 | 7 | CREATE SCHEMA test_schema; 8 | GO 9 | 10 | CREATE TABLE test_schema.test_table ( 11 | ID INT PRIMARY KEY IDENTITY, 12 | Name NVARCHAR(100), 13 | Email NVARCHAR(100) UNIQUE 14 | ); 15 | GO 16 | 17 | INSERT INTO test_schema.test_table (Name, Email) VALUES 18 | ('Alice', 'alice@example.com'), 19 | ('Bob', 'bob@example.com') 20 | ; 21 | GO 22 | 23 | CREATE VIEW test_schema.test_view AS ( 24 | SELECT * FROM test_schema.test_table WHERE Name = 'Bob' 25 | ); 26 | GO 27 | 28 | CREATE DATABASE dev; 29 | GO 30 | -------------------------------------------------------------------------------- /dbee/tests/testhelpers/bigquery.go: -------------------------------------------------------------------------------- 1 | package testhelpers 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/kndndrj/nvim-dbee/dbee/adapters" 8 | "github.com/kndndrj/nvim-dbee/dbee/core" 9 | tc "github.com/testcontainers/testcontainers-go" 10 |
"github.com/testcontainers/testcontainers-go/modules/gcloud" 11 | "github.com/testcontainers/testcontainers-go/wait" 12 | ) 13 | 14 | // BigQueryContainer is a test container for BigQuery. 15 | type BigQueryContainer struct { 16 | *gcloud.GCloudContainer 17 | ConnURL string 18 | Driver *core.Connection 19 | } 20 | 21 | // NewBigQueryContainer creates a new BigQuery container with 22 | // default adapter and connection. params.Type and params.URL are set only if they are empty. 23 | func NewBigQueryContainer(ctx context.Context, params *core.ConnectionParams) (*BigQueryContainer, error) { 24 | seedFile, err := GetTestDataFile("bigquery_seed.yaml") 25 | if err != nil { 26 | return nil, err 27 | } 28 | 29 | ctr, err := gcloud.RunBigQuery( 30 | ctx, 31 | "ghcr.io/goccy/bigquery-emulator:0.6.6", 32 | gcloud.WithProjectID("test-project"), 33 | gcloud.WithDataYAML(seedFile), 34 | tc.CustomizeRequest(tc.GenericContainerRequest{ 35 | ProviderType: GetContainerProvider(), 36 | ContainerRequest: tc.ContainerRequest{ 37 | ImagePlatform: "linux/amd64", 38 | }, 39 | }), 40 | tc.WithWaitStrategy(wait.ForLog("[bigquery-emulator] gRPC")), 41 | ) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | connURL := fmt.Sprintf("bigquery://%s?max-bytes-billed=1000&disable-query-cache=true&endpoint=%s", ctr.Settings.ProjectID, ctr.URI) 47 | if params.Type == "" { 48 | params.Type = "bigquery" 49 | } 50 | 51 | if params.URL == "" { 52 | params.URL = connURL 53 | } 54 | 55 | driver, err := adapters.NewConnection(params) 56 | if err != nil { 57 | return nil, err 58 | } 59 | 60 | return &BigQueryContainer{ 61 | GCloudContainer: ctr, 62 | ConnURL: connURL, 63 | Driver: driver, 64 | }, nil 65 | } 66 | 67 | // NewDriver creates a new connection to the container, defaulting params.URL and params.Type when they are empty.
68 | func (p *BigQueryContainer) NewDriver(params *core.ConnectionParams) (*core.Connection, error) { 69 | if params.URL == "" { 70 | params.URL = p.ConnURL 71 | } 72 | if params.Type == "" { 73 | params.Type = "bigquery" 74 | } 75 | return adapters.NewConnection(params) 76 | } 77 | -------------------------------------------------------------------------------- /dbee/tests/testhelpers/clickhouse.go: -------------------------------------------------------------------------------- 1 | package testhelpers 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/kndndrj/nvim-dbee/dbee/adapters" 7 | "github.com/kndndrj/nvim-dbee/dbee/core" 8 | tc "github.com/testcontainers/testcontainers-go" 9 | "github.com/testcontainers/testcontainers-go/modules/clickhouse" 10 | ) 11 | 12 | type ClickHouseContainer struct { 13 | *clickhouse.ClickHouseContainer 14 | ConnURL string 15 | Driver *core.Connection 16 | } 17 | 18 | // NewClickHouseContainer creates a new clickhouse container with 19 | // default adapter and connection. The params.URL is overwritten. 
20 | func NewClickHouseContainer(ctx context.Context, params *core.ConnectionParams) (*ClickHouseContainer, error) { 21 | seedFile, err := GetTestDataFile("clickhouse_seed.sql") 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | ctr, err := clickhouse.Run( 27 | ctx, 28 | "clickhouse/clickhouse-server:25.1-alpine", 29 | tc.CustomizeRequest(tc.GenericContainerRequest{ 30 | ProviderType: GetContainerProvider(), 31 | }), 32 | clickhouse.WithUsername("admin"), 33 | clickhouse.WithPassword(""), 34 | clickhouse.WithDatabase("dev"), 35 | clickhouse.WithInitScripts(seedFile.Name()), 36 | ) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | connURL, err := ctr.ConnectionString(ctx) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | if params.Type == "" { 47 | params.Type = "clickhouse" 48 | } 49 | 50 | if params.URL == "" { 51 | params.URL = connURL 52 | } 53 | 54 | driver, err := adapters.NewConnection(params) 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | return &ClickHouseContainer{ 60 | ClickHouseContainer: ctr, 61 | ConnURL: connURL, 62 | Driver: driver, 63 | }, nil 64 | } 65 | 66 | // NewDriver helper function to create a new driver with the connection URL. 
67 | func (p *ClickHouseContainer) NewDriver(params *core.ConnectionParams) (*core.Connection, error) { 68 | if params.URL == "" { 69 | params.URL = p.ConnURL 70 | } 71 | if params.Type == "" { 72 | params.Type = "clickhouse" 73 | } 74 | 75 | return adapters.NewConnection(params) 76 | } 77 | -------------------------------------------------------------------------------- /dbee/tests/testhelpers/duckdb.go: -------------------------------------------------------------------------------- 1 | package testhelpers 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "path/filepath" 7 | "strings" 8 | "time" 9 | 10 | "github.com/docker/docker/api/types/container" 11 | "github.com/kndndrj/nvim-dbee/dbee/adapters" 12 | "github.com/kndndrj/nvim-dbee/dbee/core" 13 | tc "github.com/testcontainers/testcontainers-go" 14 | "github.com/testcontainers/testcontainers-go/wait" 15 | ) 16 | 17 | // DuckDBContainer represents an in-memory DuckDB instance. 18 | type DuckDBContainer struct { 19 | tc.Container 20 | ConnURL string 21 | Driver *core.Connection 22 | TempDir string 23 | } 24 | 25 | // NewDuckDBContainer creates a new duckdb container with 26 | // default adapter and connection. The params.URL is overwritten. 27 | // It uses a temporary directory (usually the test suite tempDir) to store the db file. 28 | // The tmpDir is then mounted to the container and all the dependencies are installed 29 | // in the container file, while still being able to connect to the db file in the host. 
30 | func NewDuckDBContainer(ctx context.Context, params *core.ConnectionParams, tmpDir string) (*DuckDBContainer, error) { 31 | seedFile, err := GetTestDataFile("duckdb_seed.sql") 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | dbName, containerDBPath := "test_container.db", "/container/db" 37 | entrypointCmd := []string{ 38 | "apt-get update", 39 | "apt-get install -y curl", 40 | "curl https://install.duckdb.org | sh", 41 | "export PATH='/root/.duckdb/cli/latest':$PATH", 42 | fmt.Sprintf("duckdb %s/%s < %s", containerDBPath, dbName, seedFile.Name()), 43 | "echo 'ready'", 44 | "tail -f /dev/null", // hack to keep the container running indefinitely 45 | } 46 | 47 | req := tc.ContainerRequest{ 48 | Image: "debian:12.10-slim", 49 | Files: []tc.ContainerFile{ 50 | { 51 | Reader: seedFile, 52 | ContainerFilePath: seedFile.Name(), 53 | FileMode: 0o755, 54 | }, 55 | }, 56 | HostConfigModifier: func(hc *container.HostConfig) { 57 | hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", tmpDir, containerDBPath)) 58 | }, 59 | Cmd: []string{"sh", "-c", strings.Join(entrypointCmd, " && ")}, 60 | WaitingFor: wait.ForLog("ready").WithStartupTimeout(60 * time.Second), 61 | } 62 | 63 | ctr, err := tc.GenericContainer(ctx, tc.GenericContainerRequest{ 64 | ContainerRequest: req, 65 | ProviderType: GetContainerProvider(), 66 | Started: true, 67 | }) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | if params.Type == "" { 73 | params.Type = "duckdb" 74 | } 75 | connURL := filepath.Join(tmpDir, dbName) 76 | if params.URL == "" { 77 | params.URL = connURL 78 | } 79 | 80 | driver, err := adapters.NewConnection(params) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | return &DuckDBContainer{ 86 | Container: ctr, 87 | ConnURL: connURL, 88 | Driver: driver, 89 | TempDir: tmpDir, 90 | }, nil 91 | } 92 | 93 | // NewDriver helper function to create a new driver with the connection URL. 
94 | func (d *DuckDBContainer) NewDriver(params *core.ConnectionParams) (*core.Connection, error) { 95 | if params.Type == "" { 96 | params.Type = "duckdb" 97 | } 98 | if params.URL == "" { 99 | params.URL = d.ConnURL 100 | } 101 | 102 | return adapters.NewConnection(params) 103 | } 104 | -------------------------------------------------------------------------------- /dbee/tests/testhelpers/helper.go: -------------------------------------------------------------------------------- 1 | // Package testhelpers provides helpers for integration tests. 2 | package testhelpers 3 | 4 | import ( 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | "path/filepath" 9 | "runtime" 10 | "testing" 11 | "time" 12 | 13 | "github.com/kndndrj/nvim-dbee/dbee/core" 14 | "github.com/stretchr/testify/require" 15 | "github.com/testcontainers/testcontainers-go" 16 | ) 17 | 18 | const ( 19 | // eventBufferTime is a padding to let events come through (e.g. archived) 20 | eventBufferTime = 100 * time.Millisecond 21 | // eventTimeout is the maximum time to wait for an event to come through 22 | eventTimeout = 10 * time.Second 23 | ) 24 | 25 | // errTimeOut is an error for when an event did not finish within the expected time. 26 | var errTimeOut = fmt.Errorf("event did not finish within %v", eventTimeout) 27 | 28 | // GetContainerProvider returns the container provider type to use for the tests. 29 | // If we detect podman is available, we use it, otherwise we use docker. 30 | func GetContainerProvider() testcontainers.ProviderType { 31 | if _, err := exec.LookPath("podman"); err == nil { 32 | fmt.Println("Podman detected. Remember to set TESTCONTAINERS_RYUK_CONTAINER_PRIVILEGED=true;") 33 | return testcontainers.ProviderPodman 34 | } 35 | return testcontainers.ProviderDocker 36 | } 37 | 38 | // GetResult is a helper function for calling the Execute method on a driver 39 | // and waiting for the result to be available.
40 | func GetResult(t *testing.T, d *core.Connection, query string) ([]core.Row, core.Header, []core.CallState, error) { 41 | t.Helper() 42 | 43 | var result *core.Result 44 | outStates := make([]core.CallState, 0) 45 | outRows := make([]core.Row, 0) 46 | 47 | call := d.Execute(query, func(state core.CallState, c *core.Call) { 48 | outStates = append(outStates, state) 49 | 50 | var err error 51 | if state == core.CallStateArchived || state == core.CallStateRetrieving { 52 | result, err = c.GetResult() 53 | require.NoError(t, err, "failed getting result with %s, err: %s", state, c.Err()) 54 | outRows, err = result.Rows(0, result.Len()) 55 | require.NoError(t, err, "failed getting rows with %s, err: %s", state, c.Err()) 56 | } 57 | }) 58 | 59 | select { 60 | case <-call.Done(): 61 | time.Sleep(eventBufferTime) 62 | require.NotNil(t, result, call.Err()) 63 | return outRows, result.Header(), outStates, nil 64 | 65 | case <-time.After(eventTimeout): 66 | return nil, nil, nil, errTimeOut 67 | } 68 | } 69 | 70 | // GetResultWithCancel is a helper function for calling the Execute method on a driver 71 | // and canceling the call after the first state is received. 72 | func GetResultWithCancel(t *testing.T, d *core.Connection, query string) (*core.Result, []core.CallState, error) { 73 | t.Helper() 74 | 75 | var ( 76 | outResult *core.Result 77 | outErr error 78 | ) 79 | outStates := make([]core.CallState, 0) 80 | 81 | call := d.Execute(query, func(cs core.CallState, c *core.Call) { 82 | outStates = append(outStates, cs) 83 | c.Cancel() 84 | }) 85 | 86 | select { 87 | case <-call.Done(): 88 | time.Sleep(eventBufferTime) 89 | return outResult, outStates, outErr 90 | case <-time.After(eventTimeout): 91 | return nil, nil, errTimeOut 92 | } 93 | } 94 | 95 | // GetSchemas returns a list of schema names from the given structure. 
96 | func GetSchemas(t *testing.T, structure []*core.Structure) []string { 97 | t.Helper() 98 | 99 | schemas := make([]string, 0) 100 | for _, s := range structure { 101 | if s.Name == s.Schema { 102 | schemas = append(schemas, s.Name) 103 | continue 104 | } 105 | } 106 | return schemas 107 | } 108 | 109 | // GetModels returns a list of model names (views, table, etc) from the given structure. 110 | func GetModels(t *testing.T, structure []*core.Structure, modelType core.StructureType) []string { 111 | t.Helper() 112 | 113 | out := make([]string, 0) 114 | for _, s := range structure { 115 | for _, c := range s.Children { 116 | if c.Type == modelType { 117 | out = append(out, c.Name) 118 | continue 119 | } 120 | } 121 | } 122 | return out 123 | } 124 | 125 | // GetTestDataPath returns the path to the testdata directory. 126 | func GetTestDataPath() (string, error) { 127 | _, currentFile, _, ok := runtime.Caller(0) 128 | if !ok { 129 | return "", fmt.Errorf("failed to get current file path") 130 | } 131 | 132 | return filepath.Join(filepath.Dir(currentFile), "../testdata"), nil 133 | } 134 | 135 | // GetTestDataFile returns a file from the testdata directory. 
136 | func GetTestDataFile(filename string) (*os.File, error) { 137 | testDataPath, err := GetTestDataPath() 138 | if err != nil { 139 | return nil, err 140 | } 141 | 142 | path := filepath.Join(testDataPath, filename) 143 | return os.Open(path) 144 | } 145 | -------------------------------------------------------------------------------- /dbee/tests/testhelpers/mysql.go: -------------------------------------------------------------------------------- 1 | package testhelpers 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/kndndrj/nvim-dbee/dbee/adapters" 7 | "github.com/kndndrj/nvim-dbee/dbee/core" 8 | tc "github.com/testcontainers/testcontainers-go" 9 | tcmysql "github.com/testcontainers/testcontainers-go/modules/mysql" 10 | ) 11 | 12 | type MySQLContainer struct { 13 | *tcmysql.MySQLContainer 14 | ConnURL string 15 | Driver *core.Connection 16 | } 17 | 18 | // NewMySQLContainer creates a new MySQL container with 19 | // default adapter and connection. The params.URL is overwritten. 
20 | func NewMySQLContainer(ctx context.Context, params *core.ConnectionParams) (*MySQLContainer, error) { 21 | seedFile, err := GetTestDataFile("mysql_seed.sql") 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | ctr, err := tcmysql.Run( 27 | ctx, 28 | "mysql:9.2.0", 29 | tc.CustomizeRequest(tc.GenericContainerRequest{ 30 | ProviderType: GetContainerProvider(), 31 | }), 32 | tcmysql.WithDatabase("dev"), 33 | tcmysql.WithPassword("password"), 34 | tcmysql.WithUsername("root"), 35 | tcmysql.WithScripts(seedFile.Name()), 36 | ) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | connURL, err := ctr.ConnectionString(ctx, "tls=skip-verify") 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | if params.Type == "" { 47 | params.Type = "mysql" 48 | } 49 | 50 | if params.URL == "" { 51 | params.URL = connURL 52 | } 53 | 54 | driver, err := adapters.NewConnection(params) 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | return &MySQLContainer{ 60 | MySQLContainer: ctr, 61 | ConnURL: connURL, 62 | Driver: driver, 63 | }, nil 64 | } 65 | 66 | // NewDriver helper function to create a new driver with the connection URL. 
67 | func (p *MySQLContainer) NewDriver(params *core.ConnectionParams) (*core.Connection, error) { 68 | if params.URL == "" { 69 | params.URL = p.ConnURL 70 | } 71 | if params.Type == "" { 72 | params.Type = "mysql" 73 | } 74 | 75 | return adapters.NewConnection(params) 76 | } 77 | -------------------------------------------------------------------------------- /dbee/tests/testhelpers/oracle.go: -------------------------------------------------------------------------------- 1 | package testhelpers 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/docker/docker/api/types/container" 9 | "github.com/kndndrj/nvim-dbee/dbee/adapters" 10 | "github.com/kndndrj/nvim-dbee/dbee/core" 11 | tc "github.com/testcontainers/testcontainers-go" 12 | "github.com/testcontainers/testcontainers-go/wait" 13 | ) 14 | 15 | type OracleContainer struct { 16 | tc.Container 17 | ConnURL string 18 | Driver *core.Connection 19 | } 20 | 21 | // NewOracleContainer creates a new oracle container with 22 | // default adapter and connection. The params.URL is overwritten. 
23 | func NewOracleContainer(ctx context.Context, params *core.ConnectionParams) (*OracleContainer, error) { 24 | const ( 25 | password = "password" 26 | appUser = "tester" 27 | port = "1521/tcp" 28 | memoryLimitGB = 3 * 1024 * 1024 * 1024 29 | ) 30 | 31 | seedFile, err := GetTestDataFile("oracle_seed.sql") 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | req := tc.ContainerRequest{ 37 | Image: "gvenzl/oracle-free:23.6-slim-faststart", 38 | ExposedPorts: []string{port}, 39 | Env: map[string]string{ 40 | "ORACLE_PASSWORD": password, 41 | "APP_USER": appUser, 42 | "APP_USER_PASSWORD": password, 43 | }, 44 | WaitingFor: wait.ForLog("DATABASE IS READY TO USE!").WithStartupTimeout(5 * time.Minute), 45 | Resources: container.Resources{Memory: memoryLimitGB}, 46 | Files: []tc.ContainerFile{ 47 | { 48 | Reader: seedFile, 49 | ContainerFilePath: "/docker-entrypoint-initdb.d/" + seedFile.Name(), 50 | FileMode: 0o755, 51 | }, 52 | }, 53 | } 54 | 55 | ctr, err := tc.GenericContainer(ctx, tc.GenericContainerRequest{ 56 | ContainerRequest: req, 57 | ProviderType: GetContainerProvider(), 58 | Started: true, 59 | }) 60 | if err != nil { 61 | return nil, err 62 | } 63 | 64 | host, err := ctr.Host(ctx) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | mPort, err := ctr.MappedPort(ctx, port) 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | connURL := fmt.Sprintf("oracle://%s:%s@%s:%d/FREEPDB1", appUser, password, host, mPort.Int()) 75 | if params.Type == "" { 76 | params.Type = "oracle" 77 | } 78 | 79 | if params.URL == "" { 80 | params.URL = connURL 81 | } 82 | 83 | driver, err := adapters.NewConnection(params) 84 | if err != nil { 85 | return nil, err 86 | } 87 | 88 | return &OracleContainer{Container: ctr, ConnURL: connURL, Driver: driver}, nil 89 | } 90 | 91 | // NewDriver helper function to create a new driver with the connection URL.
92 | func (p *OracleContainer) NewDriver(params *core.ConnectionParams) (*core.Connection, error) { 93 | if params.URL == "" { 94 | params.URL = p.ConnURL 95 | } 96 | if params.Type == "" { 97 | params.Type = "oracle" 98 | } 99 | 100 | return adapters.NewConnection(params) 101 | } 102 | -------------------------------------------------------------------------------- /dbee/tests/testhelpers/postgres.go: -------------------------------------------------------------------------------- 1 | package testhelpers 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/kndndrj/nvim-dbee/dbee/adapters" 7 | "github.com/kndndrj/nvim-dbee/dbee/core" 8 | tc "github.com/testcontainers/testcontainers-go" 9 | tcpsql "github.com/testcontainers/testcontainers-go/modules/postgres" 10 | ) 11 | 12 | type PostgresContainer struct { 13 | *tcpsql.PostgresContainer 14 | ConnURL string 15 | Driver *core.Connection 16 | } 17 | 18 | // NewPostgresContainer creates a new postgres container with 19 | // default adapter and connection. The params.URL is overwritten. 
20 | func NewPostgresContainer(ctx context.Context, params *core.ConnectionParams) (*PostgresContainer, error) { 21 | seedFile, err := GetTestDataFile("postgres_seed.sql") 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | ctr, err := tcpsql.Run( 27 | ctx, 28 | "postgres:16-alpine", 29 | tcpsql.BasicWaitStrategies(), 30 | tc.CustomizeRequest(tc.GenericContainerRequest{ 31 | ProviderType: GetContainerProvider(), 32 | }), 33 | tcpsql.WithInitScripts(seedFile.Name()), 34 | tcpsql.WithDatabase("dev"), 35 | ) 36 | if err != nil { 37 | return nil, err 38 | } 39 | connURL, err := ctr.ConnectionString(ctx, "sslmode=disable") 40 | if err != nil { 41 | return nil, err 42 | } 43 | 44 | if params.Type == "" { 45 | params.Type = "postgres" 46 | } 47 | 48 | if params.URL == "" { 49 | params.URL = connURL 50 | } 51 | 52 | driver, err := adapters.NewConnection(params) 53 | if err != nil { 54 | return nil, err 55 | } 56 | 57 | return &PostgresContainer{ 58 | PostgresContainer: ctr, 59 | ConnURL: connURL, 60 | Driver: driver, 61 | }, nil 62 | } 63 | 64 | // NewDriver helper function to create a new driver with the connection URL. 
65 | func (p *PostgresContainer) NewDriver(params *core.ConnectionParams) (*core.Connection, error) { 66 | if params.URL == "" { 67 | params.URL = p.ConnURL 68 | } 69 | if params.Type == "" { 70 | params.Type = "postgres" 71 | } 72 | 73 | return adapters.NewConnection(params) 74 | } 75 | -------------------------------------------------------------------------------- /dbee/tests/testhelpers/sqlite.go: -------------------------------------------------------------------------------- 1 | package testhelpers 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "path/filepath" 7 | "strings" 8 | "time" 9 | 10 | "github.com/docker/docker/api/types/container" 11 | "github.com/kndndrj/nvim-dbee/dbee/adapters" 12 | "github.com/kndndrj/nvim-dbee/dbee/core" 13 | tc "github.com/testcontainers/testcontainers-go" 14 | "github.com/testcontainers/testcontainers-go/wait" 15 | ) 16 | 17 | type SQLiteContainer struct { 18 | tc.Container 19 | ConnURL string 20 | Driver *core.Connection 21 | TempDir string 22 | } 23 | 24 | // NewSQLiteContainer creates a new sqlite container with 25 | // default adapter and connection. The params.URL is overwritten. 26 | // It uses a temporary directory (usually the test suite tempDir) to store the db file. 27 | // The tmpDir is then mounted to the container and all the dependencies are installed 28 | // in the container file, while still being able to connect to the db file in the host. 
29 | func NewSQLiteContainer(ctx context.Context, params *core.ConnectionParams, tmpDir string) (*SQLiteContainer, error) { 30 | seedFile, err := GetTestDataFile("sqlite_seed.sql") 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | dbName, containerDBPath := "test.db", "/container/db" 36 | entrypointCmd := []string{ 37 | "apk add sqlite", 38 | fmt.Sprintf("sqlite3 %s/%s < %s", containerDBPath, dbName, seedFile.Name()), 39 | "echo 'ready'", 40 | "tail -f /dev/null", // hack to keep the container running indefinitely 41 | } 42 | 43 | req := tc.ContainerRequest{ 44 | Image: "alpine:3.21", 45 | Files: []tc.ContainerFile{ 46 | { 47 | Reader: seedFile, 48 | ContainerFilePath: seedFile.Name(), 49 | FileMode: 0o755, 50 | }, 51 | }, 52 | HostConfigModifier: func(hc *container.HostConfig) { 53 | hc.Binds = append(hc.Binds, fmt.Sprintf("%s:%s", tmpDir, containerDBPath)) 54 | }, 55 | Cmd: []string{"sh", "-c", strings.Join(entrypointCmd, " && ")}, 56 | WaitingFor: wait.ForLog("ready").WithStartupTimeout(5 * time.Second), 57 | } 58 | 59 | ctr, err := tc.GenericContainer(ctx, tc.GenericContainerRequest{ 60 | ContainerRequest: req, 61 | ProviderType: GetContainerProvider(), 62 | Started: true, 63 | }) 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | if params.Type == "" { 69 | params.Type = "sqlite" 70 | } 71 | 72 | connURL := filepath.Join(tmpDir, dbName) 73 | if params.URL == "" { 74 | params.URL = connURL 75 | } 76 | 77 | driver, err := adapters.NewConnection(params) 78 | if err != nil { 79 | return nil, err 80 | } 81 | 82 | return &SQLiteContainer{ 83 | Container: ctr, 84 | ConnURL: connURL, 85 | Driver: driver, 86 | TempDir: tmpDir, 87 | }, nil 88 | } 89 | 90 | // NewDriver helper function to create a new driver with the connection URL. 
91 | func (p *SQLiteContainer) NewDriver(params *core.ConnectionParams) (*core.Connection, error) { 92 | if params.URL == "" { 93 | params.URL = p.ConnURL 94 | } 95 | if params.Type == "" { 96 | params.Type = "sqlite" 97 | } 98 | 99 | return adapters.NewConnection(params) 100 | } 101 | -------------------------------------------------------------------------------- /dbee/tests/testhelpers/sqlserver.go: -------------------------------------------------------------------------------- 1 | package testhelpers 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/kndndrj/nvim-dbee/dbee/adapters" 7 | "github.com/kndndrj/nvim-dbee/dbee/core" 8 | tc "github.com/testcontainers/testcontainers-go" 9 | tcmssql "github.com/testcontainers/testcontainers-go/modules/mssql" 10 | ) 11 | 12 | type MSSQLServerContainer struct { 13 | *tcmssql.MSSQLServerContainer 14 | ConnURL string 15 | Driver *core.Connection 16 | } 17 | 18 | // NewSQLServerContainer creates a new MS SQL Server container with 19 | // default adapter and connection. The params.URL is overwritten. 
20 | func NewSQLServerContainer(ctx context.Context, params *core.ConnectionParams) (*MSSQLServerContainer, error) { 21 | const password = "H3ll0@W0rld" 22 | seedFile, err := GetTestDataFile("sqlserver_seed.sql") 23 | if err != nil { 24 | return nil, err 25 | } 26 | 27 | ctr, err := tcmssql.Run( 28 | ctx, 29 | "mcr.microsoft.com/mssql/server:2022-CU17-ubuntu-22.04", 30 | tcmssql.WithAcceptEULA(), // ok for testing purposes 31 | tcmssql.WithPassword(password), 32 | tc.CustomizeRequest(tc.GenericContainerRequest{ 33 | ContainerRequest: tc.ContainerRequest{ 34 | Files: []tc.ContainerFile{ 35 | { 36 | Reader: seedFile, 37 | ContainerFilePath: seedFile.Name(), 38 | FileMode: 0o644, 39 | }, 40 | }, 41 | }, 42 | ProviderType: GetContainerProvider(), 43 | }), 44 | tc.WithAfterReadyCommand( 45 | tc.NewRawCommand([]string{ 46 | "/opt/mssql-tools18/bin/sqlcmd", 47 | "-S", "localhost", 48 | "-U", "sa", 49 | "-P", password, 50 | "-No", 51 | "-i", seedFile.Name(), 52 | }), 53 | ), 54 | ) 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | connURL, err := ctr.ConnectionString(ctx, "encrypt=false", "TrustServerCertificate=true") 60 | if err != nil { 61 | return nil, err 62 | } 63 | 64 | if params.Type == "" { 65 | params.Type = "mssql" 66 | } 67 | 68 | if params.URL == "" { 69 | params.URL = connURL 70 | } 71 | 72 | driver, err := adapters.NewConnection(params) 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | return &MSSQLServerContainer{ 78 | MSSQLServerContainer: ctr, 79 | ConnURL: connURL, 80 | Driver: driver, 81 | }, nil 82 | } 83 | 84 | // NewDriver helper function to create a new driver with the connection URL. 
85 | func (p *MSSQLServerContainer) NewDriver(params *core.ConnectionParams) (*core.Connection, error) { 86 | if params.URL == "" { 87 | params.URL = p.ConnURL 88 | } 89 | if params.Type == "" { 90 | params.Type = "mssql" 91 | } 92 | 93 | return adapters.NewConnection(params) 94 | } 95 | -------------------------------------------------------------------------------- /lua/dbee.lua: -------------------------------------------------------------------------------- 1 | local install = require("dbee.install") 2 | local api = require("dbee.api") 3 | local config = require("dbee.config") 4 | 5 | ---@toc dbee.ref.contents 6 | 7 | ---@mod dbee.ref Dbee Reference 8 | ---@brief [[ 9 | ---Database Client for NeoVim. 10 | ---@brief ]] 11 | 12 | local dbee = { 13 | api = { 14 | core = api.core, 15 | ui = api.ui, 16 | }, 17 | } 18 | 19 | ---Setup function. 20 | ---Needs to be called before calling any other function. 21 | ---@param cfg? Config 22 | function dbee.setup(cfg) 23 | -- merge with defaults 24 | local merged = config.merge_with_default(cfg) 25 | 26 | -- validate config 27 | config.validate(merged) 28 | 29 | api.setup(merged) 30 | end 31 | 32 | ---Toggle dbee UI. 33 | function dbee.toggle() 34 | if api.current_config().window_layout:is_open() then 35 | dbee.close() 36 | else 37 | dbee.open() 38 | end 39 | end 40 | 41 | ---Open dbee UI. If already opened, reset window layout. 42 | function dbee.open() 43 | if api.current_config().window_layout:is_open() then 44 | return api.current_config().window_layout:reset() 45 | end 46 | api.current_config().window_layout:open() 47 | end 48 | 49 | ---Close dbee UI. 50 | function dbee.close() 51 | if not api.current_config().window_layout:is_open() then 52 | return 53 | end 54 | api.current_config().window_layout:close() 55 | end 56 | 57 | ---Check if dbee UI is open or not. 
58 | ---@return boolean 59 | function dbee.is_open() 60 | return api.current_config().window_layout:is_open() 61 | end 62 | 63 | ---Execute a query on current connection. 64 | ---Convenience wrapper around some api functions that executes a query on 65 | ---current connection and pipes the output to result UI. 66 | ---@param query string 67 | function dbee.execute(query) 68 | local conn = api.core.get_current_connection() 69 | if not conn then 70 | error("no connection currently selected") 71 | end 72 | 73 | local call = api.core.connection_execute(conn.id, query) 74 | api.ui.result_set_call(call) 75 | 76 | dbee.open() 77 | end 78 | 79 | ---Store currently displayed result. 80 | ---Convenience wrapper around some api functions. 81 | ---@param format string format of the output -> "csv"|"json"|"table" 82 | ---@param output string where to pipe the results -> "file"|"yank"|"buffer" 83 | ---@param opts { from: integer, to: integer, extra_arg: any } 84 | function dbee.store(format, output, opts) 85 | local call = api.ui.result_get_call() 86 | if not call then 87 | error("no current call to store") 88 | end 89 | 90 | api.core.call_store_result(call.id, format, output, opts) 91 | end 92 | 93 | ---Supported install commands. 94 | ---@alias install_command 95 | ---| '"wget"' 96 | ---| '"curl"' 97 | ---| '"bitsadmin"' 98 | ---| '"go"' 99 | ---| '"cgo"' 100 | 101 | ---Install dbee backend binary. 102 | ---@param command? install_command Preferred install command 103 | ---@see install_command 104 | function dbee.install(command) 105 | install.exec(command) 106 | end 107 | 108 | return dbee 109 | -------------------------------------------------------------------------------- /lua/dbee/api/__register.lua: -------------------------------------------------------------------------------- 1 | -- This file is automatically generated using "dbee -manifest " 2 | -- DO NOT EDIT!
return function()
  -- Register host
  vim.fn["remote#host#Register"]("nvim_dbee", "x", function()
    return vim.fn.jobstart({ "dbee" }, {
      rpc = true,
      detach = true,
      on_stderr = function(_, data, _)
        for _, line in ipairs(data) do
          print(line)
        end
      end,
    })
  end)

  -- Manifest
  vim.fn["remote#host#RegisterPlugin"]("nvim_dbee", "0", {
    { type = "function", name = "DbeeAddHelpers", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeCallCancel", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeCallDisplayResult", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeCallStoreResult", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeConnectionExecute", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeConnectionGetCalls", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeConnectionGetColumns", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeConnectionGetHelpers", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeConnectionGetParams", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeConnectionGetStructure", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeConnectionListDatabases", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeConnectionSelectDatabase", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeCreateConnection", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeDeleteConnection", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeGetConnections", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeGetCurrentConnection", sync = true, opts = vim.empty_dict() },
    { type = "function", name = "DbeeSetCurrentConnection", sync = true, opts = vim.empty_dict() },
  })
end

-------------------------------------------------------------------------------- /lua/dbee/api/init.lua: --------------------------------------------------------------------------------
return {
  core = require("dbee.api.core"),
  ui = require("dbee.api.ui"),
  setup = require("dbee.api.state").setup,
  current_config = require("dbee.api.state").config,
}

-------------------------------------------------------------------------------- /lua/dbee/api/state.lua: --------------------------------------------------------------------------------
local floats = require("dbee.ui.common.floats")
local DrawerUI = require("dbee.ui.drawer")
local EditorUI = require("dbee.ui.editor")
local ResultUI = require("dbee.ui.result")
local CallLogUI = require("dbee.ui.call_log")
local Handler = require("dbee.handler")
local install = require("dbee.install")
local register = require("dbee.api.__register")

-- public and private module objects
local M = {}
local m = {}

-- is core set up?
m.core_loaded = false
-- is ui set up?
m.ui_loaded = false
-- was setup function called?
m.setup_called = false
-- were the remote plugin registered and the install dir put on PATH?
m.env_prepared = false
---@type Config
m.config = {}

-- Registers the remote plugin and prepends the binary install dir to PATH.
-- This runs at most once. If a later step of setup_handler fails (e.g. the
-- handler cannot be created), a retry must not register the host again or
-- keep growing PATH with duplicate entries.
local function prepare_env()
  if m.env_prepared then
    return
  end

  -- register remote plugin
  register()

  -- add install binary to path
  local pathsep = ":"
  if vim.fn.has("win32") == 1 then
    pathsep = ";"
  end
  vim.env.PATH = install.dir() .. pathsep .. vim.env.PATH

  m.env_prepared = true
end

-- Lazily creates the core handler from the stored config.
-- Errors if setup() was not called yet; a no-op once the core is loaded.
local function setup_handler()
  if m.core_loaded then
    return
  end

  if not m.setup_called then
    error("setup() has not been called yet")
  end

  prepare_env()

  m.handler = Handler:new(m.config.sources)
  m.handler:add_helpers(m.config.extra_helpers)

  -- activate default connection if present (best effort: failures are ignored)
  if m.config.default_connection then
    pcall(m.handler.set_current_connection, m.handler, m.config.default_connection)
  end

  m.core_loaded = true
end

-- Lazily creates all UI components (loading the core first if needed).
local function setup_ui()
  if m.ui_loaded then
    return
  end

  setup_handler()

  -- configure options for floating windows
  floats.configure(m.config.float_options)

  -- initiate all UI elements
  m.result = ResultUI:new(m.handler, m.config.result)
  m.call_log = CallLogUI:new(m.handler, m.result, m.config.call_log)
  m.editor = EditorUI:new(m.handler, m.result, m.config.editor)
  m.drawer = DrawerUI:new(m.handler, m.editor, m.result, m.config.drawer)

  m.ui_loaded = true
end

-- Stores the config. Errors when called more than once.
---@param cfg Config
function M.setup(cfg)
  if m.setup_called then
    error("setup() can only be called once")
  end
  m.config = cfg

  m.setup_called = true
end

---@return boolean
function M.is_core_loaded()
  return m.core_loaded
end

---@return boolean
function M.is_ui_loaded()
  return m.ui_loaded
end

---@return Handler
function M.handler()
  setup_handler()
  return m.handler
end

---@return EditorUI
function M.editor()
  setup_ui()
  return m.editor
end

---@return CallLogUI
function M.call_log()
  setup_ui()
  return m.call_log
end

---@return DrawerUI
function M.drawer()
  setup_ui()
  return m.drawer
end

---@return ResultUI
function M.result()
  setup_ui()
  return m.result
end

---@return Config
function M.config()
  return m.config
end

return M

-------------------------------------------------------------------------------- /lua/dbee/doc.lua: --------------------------------------------------------------------------------
---@mod dbee.ref.types Types
---@brief [[
---Overview of types used in DBee API.
---@brief ]]

---@divider -
---@tag dbee.ref.types.table
---@brief [[
---Table related types
---@brief ]]

---Table column
---@class Column
---@field name string name of the column
---@field type string database type of the column

---Table Materialization.
---@alias materialization
---| '"table"'
---| '"view"'

---Options for gathering table specific info.
---@class TableOpts
---@field table string
---@field schema string
---@field materialization materialization

---Table helper queries by name.
---@alias table_helpers table

---@divider -
---@tag dbee.ref.types.call
---@brief [[
---Call related types.
---@brief ]]

---ID of a call.
---@alias call_id string

---State of a call.
---@alias call_state
---| '"unknown"'
---| '"executing"'
---| '"executing_failed"'
---| '"retrieving"'
---| '"retrieving_failed"'
---| '"archived"'
---| '"archive_failed"'
---| '"canceled"'

---Details and stats of a single call to database.
---@class CallDetails
---@field id call_id
---@field time_taken_us integer duration (time period) in microseconds
---@field query string
---@field state call_state
---@field timestamp_us integer time in microseconds
---@field error? string error message in case of error

---@divider -
---@tag dbee.ref.types.connection
---@brief [[
---Connection related types.
---@brief ]]

---ID of a connection.
---@alias connection_id string

---Parameters of a connection.
---@class ConnectionParams
---@field id connection_id
---@field name string
---@field type string
---@field url string

---@divider -
---@tag dbee.ref.types.structure
---@brief [[
---Database structure related types.
---@brief ]]

---Type of node in database structure.
---@alias structure_type
---| '""'
---| '"table"'
---| '"history"'
---| '"database_switch"'
---| '"view"'

---Structure of database.
---@class DBStructure
---@field name string display name
---@field type structure_type type of node in structure
---@field schema string? parent schema
---@field children DBStructure[]? child layout nodes

---@divider -
---@tag dbee.ref.types.events
---@brief [[
---Event related types.
---@brief ]]

---Available core events.
---@alias core_event_name
---| '"call_state_changed"' {call}
---| '"current_connection_changed"' {conn_id}
---| '"database_selected"' {conn_id, database_name}

---Available editor events.
---@alias editor_event_name
---| '"note_state_changed"' {note_id}
---| '"note_removed"' {note_id}
---| '"note_created"' {note_id}
---| '"current_note_changed"' {note_id}

---Event handler function.
---@alias event_listener fun(data: any)

local M = {}
return M

-------------------------------------------------------------------------------- /lua/dbee/handler/__events.lua: --------------------------------------------------------------------------------
-- This package is used for triggering lua callbacks from go.
-- Callbacks are registered under an event name; triggering that event calls each of them.
local M = {}

-- Registered listeners, grouped by event name.
---@type table<core_event_name, event_listener[]>
local callbacks = {}

-- Adds a listener for an event; listeners of one event run in registration order.
---@param event core_event_name event name to register the callback for
---@param cb event_listener callback function - "data" argument type depends on the event
function M.register(event, cb)
  callbacks[event] = callbacks[event] or {}
  table.insert(callbacks[event], cb)
end

-- Calls every listener of the event with data. The calls are deferred with
-- vim.schedule, so they run on the main loop rather than inside this call.
---@param event core_event_name
---@param data any
function M.trigger(event, data)
  vim.schedule(function()
    local cbs = callbacks[event] or {}
    for _, cb in ipairs(cbs) do
      cb(data)
    end
  end)
end

return M

-------------------------------------------------------------------------------- /lua/dbee/health.lua: --------------------------------------------------------------------------------
local install = require("dbee.install")

local M = {}

-- Runs a shell command and returns its output with all newlines removed.
---@param cmd string
---@return string
local function run_cmd(cmd)
  local handle = assert(io.popen(cmd))
  local result = handle:read("*all") or ""
  handle:close()

  -- string.gsub also returns the substitution count; return only the string
  local stripped = result:gsub("\n", "")
  return stripped
end

---@return string _ path of git repo
local function repo()
  local p, _ = debug.getinfo(1).source:sub(2):gsub("/lua/dbee/health.lua$", "/")
  return p
end

-- Gets a git hash from which the go binary is compiled.
---@return string
local function get_go_hash()
  -- quote the binary path (as repo() is quoted for git): install paths may
  -- contain spaces, e.g. a Windows user directory
  return run_cmd(string.format("%q -version", install.bin()))
end

-- Gets currently checked out git hash.
---@return string
local function get_current_hash()
  return run_cmd(string.format("git -C %q rev-parse HEAD", repo()))
end

-- Gets git hash of the install manifest
---@return string
local function get_manifest_hash()
  return install.version()
end

-- Health check: verifies the binary is executable and that its reported
-- version matches either the checked-out git HEAD or the install manifest.
function M.check()
  vim.health.start("DBee report")

  if vim.fn.executable(install.bin()) ~= 1 then
    vim.health.error("Binary not executable: " .. install.bin() .. ".")
    return
  end

  local go_hash = get_go_hash()
  if go_hash == "unknown" then
    vim.health.error("Could not determine binary version.")
    return
  end

  local manifest_hash = get_manifest_hash()

  -- git is only needed to read the current HEAD; the manifest comparison
  -- works without it
  if vim.fn.executable("git") ~= 1 then
    if go_hash == manifest_hash then
      vim.health.ok("Binary version matches version of install manifest.")
      return
    end
    vim.health.warn("Git not installed -- could not compare binary version with current HEAD.")
    return
  end

  local current_hash = get_current_hash()

  if go_hash == current_hash then
    vim.health.ok("Binary version matches version of current HEAD.")
    return
  elseif go_hash == manifest_hash then
    vim.health.ok("Binary version matches version of install manifest.")
    return
  end

  vim.health.error(
    string.format(
      "Binary version %q doesn't match either:\n - current hash: %q or\n - hash of install manifest %q.",
      go_hash,
      current_hash,
      manifest_hash
    )
  )
end

return M

-------------------------------------------------------------------------------- /lua/dbee/install/__manifest.lua: --------------------------------------------------------------------------------
-- This file is automatically generated using CI pipeline
-- DO NOT EDIT!
local M = {}

-- Links to binary releases
M.urls = {
  ["android/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_android_amd64.tar.gz",
  ["android/arm64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_android_arm64.tar.gz",
  ["darwin/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_darwin_amd64.tar.gz",
  ["darwin/arm64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_darwin_arm64.tar.gz",
  ["dragonfly/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_dragonfly_amd64.tar.gz",
  ["freebsd/386"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_freebsd_386.tar.gz",
  ["freebsd/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_freebsd_amd64.tar.gz",
  ["freebsd/arm"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_freebsd_arm.tar.gz",
  ["freebsd/arm64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_freebsd_arm64.tar.gz",
  ["freebsd/riscv64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_freebsd_riscv64.tar.gz",
  ["illumos/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_illumos_amd64.tar.gz",
  ["linux/386"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_386.tar.gz",
  ["linux/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_amd64.tar.gz",
  ["linux/arm"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_arm.tar.gz",
  ["linux/arm64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_arm64.tar.gz",
  ["linux/loong64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_loong64.tar.gz",
  ["linux/mips64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_mips64.tar.gz",
  ["linux/mips64le"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_mips64le.tar.gz",
  ["linux/ppc64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_ppc64.tar.gz",
  ["linux/ppc64le"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_ppc64le.tar.gz",
  ["linux/riscv64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_riscv64.tar.gz",
  ["linux/s390x"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_linux_s390x.tar.gz",
  ["netbsd/386"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_netbsd_386.tar.gz",
  ["netbsd/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_netbsd_amd64.tar.gz",
  ["netbsd/arm"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_netbsd_arm.tar.gz",
  ["netbsd/arm64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_netbsd_arm64.tar.gz",
  ["openbsd/386"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_openbsd_386.tar.gz",
  ["openbsd/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_openbsd_amd64.tar.gz",
  ["openbsd/arm"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_openbsd_arm.tar.gz",
  ["openbsd/arm64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_openbsd_arm64.tar.gz",
  ["solaris/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_solaris_amd64.tar.gz",
  ["windows/386"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_windows_386.tar.gz",
  ["windows/amd64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_windows_amd64.tar.gz",
  ["windows/arm"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_windows_arm.tar.gz",
  ["windows/arm64"] = "https://github.com/kndndrj/nvim-dbee/releases/download/v0.1.9/dbee_windows_arm64.tar.gz",
}

-- Git sha of compiled binaries
M.version = "af5075f31ede9e7d76c87babdee0f70340061660"

return M

-------------------------------------------------------------------------------- /lua/dbee/layouts/tools.lua: --------------------------------------------------------------------------------
---@alias _layout { type: string, winid: integer, bufnr: integer, win_opts: { string: any}, children: _layout[] }

---@alias layout_egg { layout: _layout, restore: string }

-- vim.fn.winlayout() example structure:
-- { "row", { { "leaf", winid }, { "col", { { "leaf", winid }, { "leaf", winid } } } } }

local M = {}

-- Collects ids of the current tabpage's windows whose config has an empty
-- "relative" field, i.e. the non-floating ones, in tabpage order.
---@return integer[] # list of non floating window ids
local function list_non_floating_wins()
  local tabpage = vim.api.nvim_get_current_tabpage()
  local regular = {}
  for _, wid in ipairs(vim.api.nvim_tabpage_list_wins(tabpage)) do
    if vim.api.nvim_win_get_config(wid).relative == "" then
      regular[#regular + 1] = wid
    end
  end
  return regular
end

-- Leaves only the given window open, like ":only", but floating windows are
-- not touched.
---@param winid integer window to keep; nil/false or 0 selects the current window
function M.make_only(winid)
  local keep = winid
  if not keep or keep == 0 then
    keep = vim.api.nvim_get_current_win()
  end

  for _, wid in ipairs(list_non_floating_wins()) do
    if wid ~= keep then
      -- window numbers shift as windows close: resolve the number just in time
      vim.cmd(vim.fn.win_id2win(wid) .. "wincmd c")
    end
  end
end

-- Counterpart of the builtin "winrestcmd()" that skips floating windows.
-- https://github.com/neovim/neovim/blob/fcf3519c65a2d6736de437f686e788684a6c8564/src/nvim/eval/window.c#L770
---@return string
local function winrestcmd()
  local parts = {}

  -- emitted twice, like the builtin, so that some layouts settle correctly
  for _ = 1, 2 do
    for nr, wid in ipairs(list_non_floating_wins()) do
      parts[#parts + 1] = string.format("%dresize %d|", nr, vim.api.nvim_win_get_height(wid))
      parts[#parts + 1] = string.format("vert %dresize %d|", nr, vim.api.nvim_win_get_width(wid))
    end
  end

  return table.concat(parts)
end

-- Turns a vim.fn.winlayout() node into a _layout table. Leaves additionally
-- record the buffer they show and the values of purely window-local options.
local function add_details(layout)
  local kind, payload = layout[1], layout[2]

  if kind ~= "leaf" then
    local children = {}
    for _, child in ipairs(payload) do
      children[#children + 1] = add_details(child)
    end
    return { type = kind, children = children }
  end

  -- for a leaf, the payload is the window id
  local wo = vim.wo[payload]
  local win_opts = {}
  for name, info in pairs(vim.api.nvim_get_all_options_info()) do
    if info.scope == "win" and info.global_local == false then
      win_opts[name] = wo[name]
    end
  end

  ---@type _layout
  local leaf = {
    type = kind,
    winid = payload,
    bufnr = vim.fn.winbufnr(payload),
    win_opts = win_opts,
  }
  return leaf
end

---@return layout_egg layout egg (use with restore())
function M.save()
  local raw = vim.fn.winlayout()
  local restore_cmd = winrestcmd()
  return { layout = add_details(raw), restore = restore_cmd }
end

-- Rebuilds a saved layout starting from the current window.
---@param layout _layout
local function apply_layout(layout)
  if layout.type == "leaf" then
    -- show the buffer this window held, if it still exists
    if vim.fn.bufexists(layout.bufnr) == 1 then
      vim.cmd("b " .. layout.bufnr)
    end
    for name, value in pairs(layout.win_opts) do
      if value ~= nil then
        vim.wo[name] = value
      end
    end
    return
  end

  -- "col" stacks children vertically; other containers place them side by side
  local split_cmd = layout.type == "col" and "rightbelow split" or "rightbelow vsplit"

  -- the current window hosts the first child; split once for every other child
  local wins = { vim.fn.win_getid() }
  for _ = 2, #layout.children do
    vim.cmd(split_cmd)
    wins[#wins + 1] = vim.fn.win_getid()
  end

  for i, wid in ipairs(wins) do
    vim.fn.win_gotoid(wid)
    apply_layout(layout.children[i])
  end
end

---@param egg layout_egg layout to restore
function M.restore(egg)
  local e = egg or {}
  if not (e.layout and e.restore) then
    return
  end

  -- start from one fresh window whose scratch buffer is deleted at the end
  vim.cmd("new")
  M.make_only(0)
  local scratch = vim.api.nvim_get_current_buf()

  apply_layout(e.layout)
  vim.cmd(e.restore)

  vim.cmd("bd " .. scratch)
end

return M

-------------------------------------------------------------------------------- /lua/dbee/ui/common/init.lua: --------------------------------------------------------------------------------
local floats = require("dbee.ui.common.floats")
local utils = require("dbee.utils")

local M = {}

-- expose floats
M.float_editor = floats.editor
M.float_hover = floats.hover
M.float_prompt = floats.prompt

-- Creates a blank hidden buffer.
---@param name string
---@param opts? table buffer options
---@return integer bufnr
function M.create_blank_buffer(name, opts)
  local bufnr = vim.api.nvim_create_buf(false, true)

  -- if the name cannot be set (e.g. it is taken), retry with a random suffix
  if not pcall(vim.api.nvim_buf_set_name, bufnr, name) then
    pcall(vim.api.nvim_buf_set_name, bufnr, name .. "-" .. utils.random_string())
  end

  M.configure_buffer_options(bufnr, opts or {})
  return bufnr
end

-- Sets each given option on the buffer; does nothing without a buffer.
---@param bufnr integer
---@param opts? table buffer options
function M.configure_buffer_options(bufnr, opts)
  if not bufnr then
    return
  end
  for name, value in pairs(opts or {}) do
    vim.api.nvim_buf_set_option(bufnr, name, value)
  end
end

-- Sets each given option on the window; does nothing without a window.
---@param winid integer
---@param opts? table window options
function M.configure_window_options(winid, opts)
  if not winid then
    return
  end
  for name, value in pairs(opts or {}) do
    vim.api.nvim_win_set_option(winid, name, value)
  end
end

-- Sets mappings to the buffer.
---@param bufnr integer
---@param actions table
---@param keymap key_mapping[]
function M.configure_buffer_mappings(bufnr, actions, keymap)
  if not bufnr then
    return
  end
  actions = actions or {}
  keymap = keymap or {}

  for _, km in ipairs(keymap) do
    if km.key and km.mode then
      -- an action is either a name looked up in `actions` or a function
      local action
      if type(km.action) == "string" then
        action = actions[km.action]
      elseif type(km.action) == "function" then
        action = km.action
      end

      if action then
        -- build a fresh options table: km.opts comes from the caller's config
        -- and must not get a buffer number written into it
        local map_opts = vim.tbl_extend("force", km.opts or { noremap = true, nowait = true }, { buffer = bufnr })
        vim.keymap.set(km.mode, km.key, action, map_opts)
      end
    end
  end
end

return M

-------------------------------------------------------------------------------- /lua/dbee/ui/drawer/expansion.lua: --------------------------------------------------------------------------------
local M = {}

-- applies the expansion on new nodes
---@param tree NuiTree tree to apply the expansion map to
---@param expansion table expansion map ( id:is_expanded mapping )
function M.set(tree, expansion)
  -- first pass: load lazy_loaded children.
  -- A nested node only exists after its parent's lazy children were loaded,
  -- and pairs() visits ids in no particular order, so keep sweeping until a
  -- sweep finds no new expanded node.
  local loaded = {}
  local progress = true
  while progress do
    progress = false
    for id, t in pairs(expansion) do
      if t and not loaded[id] then
        local node = tree:get_node(id) --[[@as DrawerUINode]]
        if node then
          loaded[id] = true
          progress = true
          -- if function for getting layout exist, call it
          if type(node.lazy_children) == "function" then
            tree:set_nodes(node.lazy_children(), node.id)
          end
        end
      end
    end
  end

  -- second pass: expand nodes
  for id, t in pairs(expansion) do
    if t then
      local node = tree:get_node(id) --[[@as DrawerUINode]]
      if node then
        node:expand()
      end
    end
  end
end

-- gets an expansion config to restore the expansion on new nodes
---@param tree NuiTree
---@return table
function M.get(tree)
  ---@type table
  local nodes = {}

  -- records expanded nodes, descending into every node that has children
  local function process(node)
    if node:is_expanded() then
      nodes[node:get_id()] = true
    end

    if node:has_children() then
      for _, n in ipairs(tree:get_nodes(node:get_id())) do
        process(n)
      end
    end
  end

  for _, node in ipairs(tree:get_nodes()) do
    process(node)
  end

  return nodes
end

return M

-------------------------------------------------------------------------------- /lua/dbee/ui/drawer/menu.lua: --------------------------------------------------------------------------------
local NuiMenu = require("nui.menu")
local NuiInput = require("nui.input")

local M = {}

---@alias menu_select fun(opts?: { title: string, items: string[], on_confirm: fun(selection: string), on_yank: fun(selection: string) })
---@alias menu_input fun(opts?: { title: string, default: string, on_confirm: fun(value: string) })

-- Pick items from a list.
---@param opts { relative_winid: integer, items: string[], on_confirm: fun(item: string), on_yank: fun(item:string), title: string, mappings: key_mapping[] }
function M.select(opts)
  opts = opts or {}
  local winid = opts.relative_winid
  if not winid or not vim.api.nvim_win_is_valid(winid) then
    error("no window id provided")
  end

  -- the popup spans the full window width, one row below the cursor line
  local width = vim.api.nvim_win_get_width(winid)
  local cursor_row = vim.api.nvim_win_get_cursor(winid)[1]

  local popup_options = {
    relative = { type = "win", winid = winid },
    position = { row = cursor_row + 1, col = 0 },
    size = { width = width },
    zindex = 160,
    border = {
      style = { "─", "─", "─", "", "─", "─", "─", "" },
      text = { top = opts.title or "", top_align = "left" },
    },
    win_options = { cursorline = true },
  }

  local menu_lines = {}
  for _, item in ipairs(opts.items or {}) do
    menu_lines[#menu_lines + 1] = NuiMenu.item(item)
  end

  -- NuiMenu's own close/submit keys are disabled; the configured mappings
  -- below drive confirming, yanking and closing instead.
  -- NOTE(review): the two empty key strings in focus_next/focus_prev look like
  -- special-key names lost in this copy of the file — confirm against upstream.
  local menu = NuiMenu(popup_options, {
    lines = menu_lines,
    keymap = {
      focus_next = { "j", "", "" },
      focus_prev = { "k", "", "" },
      close = {},
      submit = {},
    },
    on_submit = function() end,
  })

  -- configured action names mapped to their handlers
  local handlers = {
    menu_confirm = opts.on_confirm,
    menu_yank = opts.on_yank,
    menu_close = function() end,
  }

  for _, km in ipairs(opts.mappings or {}) do
    local handler = handlers[km.action]
    if handler then
      menu:map(km.mode, km.key, function()
        local selected = menu.tree:get_node()
        menu:unmount()
        if selected then
          handler(selected.text)
        end
      end, km.opts or { noremap = true, nowait = true })
    end
  end

  menu:mount()
end

return M

-------------------------------------------------------------------------------- /lua/dbee/ui/editor/welcome.lua: --------------------------------------------------------------------------------
local M = {}

function M.banner()
  return {
    "-- [ Enter insert mode to clear ]",
    "",
    "",
---@param opts { relative_winid: integer, default_value: string, on_confirm: fun(item: string), title: string, mappings: key_mapping[] }
function M.input(opts)
  -- same guard as M.select: a missing opts table yields the regular error
  -- below instead of an "attempt to index a nil value" crash
  opts = opts or {}
  if not opts.relative_winid or not vim.api.nvim_win_is_valid(opts.relative_winid) then
    error("no window id provided")
  end

  local width = vim.api.nvim_win_get_width(opts.relative_winid)
  local row, _ = unpack(vim.api.nvim_win_get_cursor(opts.relative_winid))

  local popup_options = {
    relative = {
      type = "win",
      winid = opts.relative_winid,
    },
    position = {
      row = row + 1,
      col = 0,
    },
    size = {
      width = width,
    },
    zindex = 160,
    border = {
      style = { "─", "─", "─", "", "─", "─", "─", "" },
      text = {
        top = opts.title or "",
        top_align = "left",
      },
    },
    win_options = {
      cursorline = false,
    },
  }

  local input = NuiInput(popup_options, {
    default_value = opts.default_value,
    on_submit = opts.on_confirm,
  })

  -- configure mappings: each mapped action receives the first line of the input
  for _, km in ipairs(opts.mappings or {}) do
    local action
    if km.action == "menu_confirm" then
      action = opts.on_confirm
    elseif km.action == "menu_close" then
      action = function() end
    end

    local map_opts = km.opts or { noremap = true, nowait = true }

    if action then
      input:map(km.mode, km.key, function()
        local line = vim.api.nvim_buf_get_lines(input.bufnr, 0, 1, false)[1]
        input:unmount()
        action(line)
      end, map_opts)
    end
  end

  input:mount()
end

return M

-------------------------------------------------------------------------------- /lua/dbee/ui/editor/welcome.lua: --------------------------------------------------------------------------------
local M = {}

-- Lines of the welcome note shown in the editor.
function M.banner()
  return {
    "-- [ Enter insert mode to clear ]",
    "",
    "",
    "-- Welcome to",
    "-- ",
    "-- ██████████ ███████████",
    "-- ░░███░░░░███ ░░███░░░░░███",
    "-- ░███ ░░███ ░███ ░███ ██████ ██████",
    "-- ░███ ░███ ░██████████ ███░░███ ███░░███",
    "-- ░███ ░███ ░███░░░░░███░███████ ░███████",
    "-- ░███ ███ ░███ ░███░███░░░ ░███░░░",
    "-- ██████████ ███████████ ░░██████ ░░██████",
    "-- ░░░░░░░░░░ ░░░░░░░░░░░ ░░░░░░ ░░░░░░",
    "",
    "",
    '-- Type ":h dbee.txt" to learn more about the plugin.',
    "",
    '-- Report issues to: "github.com/kndndrj/nvim-dbee/issues".',
    "",
    "-- Existing users: DO NOT PANIC:",
    "-- Your notes and connections were moved from:",
    '-- "' .. vim.fn.stdpath("cache") .. '/dbee/notes" and',
    '-- "' .. vim.fn.stdpath("cache") .. '/dbee/persistence.json"',
    "-- to:",
    '-- "' .. vim.fn.stdpath("state") .. '/dbee/notes" and',
    '-- "' .. vim.fn.stdpath("state") .. '/dbee/persistence.json"',
    "-- Move them manually or adjust the config accordingly.",
    '-- see the "Breaking Changes" issue on github for more info.',
  }
end

return M

-------------------------------------------------------------------------------- /lua/dbee/ui/result/progress.lua: --------------------------------------------------------------------------------
local M = {}

---@alias progress_config { text_prefix: string, spinner: string[] }

--- Display an updated progress loader in the specified buffer
---@param bufnr integer -- buffer to display the progress in
---@param opts? progress_config
---@return fun() # cancel function
function M.display(bufnr, opts)
  if not bufnr then
    return function() end
  end
  opts = opts or {}
  local text_prefix = opts.text_prefix or "Loading..."
  local spinner = opts.spinner or { "|", "/", "-", "\\" }

  local icon_index = 1
  local start_time = vim.fn.reltimefloat(vim.fn.reltime())

  -- timer callbacks receive the id of the timer that fired them
  local function update(timer)
    -- the buffer can be wiped while the loader is still running: stop the
    -- repeating timer instead of raising errors from its callback
    if not vim.api.nvim_buf_is_valid(bufnr) then
      pcall(vim.fn.timer_stop, timer)
      return
    end

    local passed_time = vim.fn.reltimefloat(vim.fn.reltime()) - start_time
    icon_index = (icon_index % #spinner) + 1

    vim.api.nvim_buf_set_option(bufnr, "modifiable", true)
    local line = string.format("%s %.3f seconds %s ", text_prefix, passed_time, spinner[icon_index])
    vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, { line })
    vim.api.nvim_buf_set_option(bufnr, "modifiable", false)
  end

  local timer = vim.fn.timer_start(100, update, { ["repeat"] = -1 })
  return function()
    pcall(vim.fn.timer_stop, timer)
  end
end

return M

-------------------------------------------------------------------------------- /lua/dbee/utils.lua: --------------------------------------------------------------------------------
local M = {}

-- private variable with registered onces
---@type table
local used_onces = {}

-- Returns true the first time it is called with a given id, false afterwards.
---@param id string unique id of this singleton bool
---@return boolean
function M.once(id)
  id = id or ""

  if used_onces[id] then
    return false
  end

  used_onces[id] = true

  return true
end

-- Get cursor range of current selection
---@return integer start row
---@return integer start column
---@return integer end row
---@return integer end column
function M.visual_selection()
  -- return to normal mode ('< and '> become available only after you exit visual mode)
  -- NOTE(review): the key string passed below is empty in this copy of the
  -- file, although the comment above implies an escape sequence — confirm.
  local key = vim.api.nvim_replace_termcodes("", true, false, true)
  vim.api.nvim_feedkeys(key, "x", false)

  local _, srow, scol, _ = unpack(vim.fn.getpos("'<"))
  local _, erow, ecol, _ = unpack(vim.fn.getpos("'>"))

  -- make sure the start position comes first
  if srow > erow or (srow == erow and scol > ecol) then
    srow, scol, erow, ecol = erow, ecol, srow, scol
  end

  -- linewise selections report a huge end column; clamp it to the byte
  -- length of the end line instead of an arbitrary constant, so long lines
  -- are not truncated
  ecol = math.min(ecol, #vim.fn.getline(erow))

  return srow - 1, scol - 1, erow - 1, ecol
end

---@param level "info"|"warn"|"error"
---@param message string
---@param subtitle? string
function M.log(level, message, subtitle)
  -- log level (unknown levels map to OFF)
  local levels = {
    info = vim.log.levels.INFO,
    warn = vim.log.levels.WARN,
    error = vim.log.levels.ERROR,
  }
  local l = levels[level] or vim.log.levels.OFF

  -- "[subtitle]: message", or just the message (no stray leading space)
  local text = message
  if subtitle then
    text = "[" .. subtitle .. "]: " .. message
  end
  vim.notify(text, l, { title = "nvim-dbee" })
end

-- Gets keys of a map and sorts them by name
---@param obj table map-like table
---@return string[]
function M.sorted_keys(obj)
  local keys = {}
  for k, _ in pairs(obj) do
    table.insert(keys, k)
  end
  table.sort(keys)
  return keys
end

-- create an autocmd that is associated with a window rather than a buffer.
-- The callback only runs when the event's buffer is displayed in winid
-- (as reported by bufwinid); the autocmd deletes itself once winid is closed.
---@param events string[]
---@param winid integer
---@param opts table
local function create_window_autocmd(events, winid, opts)
  opts = opts or {}
  if not events or not winid or not opts.callback then
    return
  end

  local cb = opts.callback

  opts.callback = function(event)
    -- remove autocmd if window is closed
    if not vim.api.nvim_win_is_valid(winid) then
      vim.api.nvim_del_autocmd(event.id)
      return
    end

    local wid = vim.fn.bufwinid(event.buf or -1)
    if wid ~= winid then
      return
    end
    cb(event)
  end

  vim.api.nvim_create_autocmd(events, opts)
end

-- create an autocmd just once in a single place in code.
-- If opts hold a "window" key, autocmd is defined per window rather than a buffer.
109 | -- If window and buffer are provided, this results in an error. 110 | ---@param events string[] events list as defined in nvim api 111 | ---@param opts table options as in api 112 | function M.create_singleton_autocmd(events, opts) 113 | if opts.window and opts.buffer then 114 | error("cannot register autocmd for buffer and window at the same time") 115 | end 116 | 117 | local caller_info = debug.getinfo(2) 118 | if not caller_info or not caller_info.name or not caller_info.currentline then 119 | error("could not determine function caller") 120 | end 121 | 122 | if 123 | not M.once( 124 | "autocmd_singleton_" 125 | .. caller_info.name 126 | .. caller_info.currentline 127 | .. tostring(opts.window) 128 | .. tostring(opts.buffer) 129 | ) 130 | then 131 | -- already configured 132 | return 133 | end 134 | 135 | if opts.window then 136 | local window = opts.window 137 | opts.window = nil 138 | create_window_autocmd(events, window, opts) 139 | return 140 | end 141 | 142 | vim.api.nvim_create_autocmd(events, opts) 143 | end 144 | 145 | local random_charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" 146 | 147 | --- Generate a random string 148 | ---@return string _ random string of 10 characters 149 | function M.random_string() 150 | local function r(length) 151 | if length < 1 then 152 | return "" 153 | end 154 | 155 | local i = math.random(1, #random_charset) 156 | return r(length - 1) .. random_charset:sub(i, i) 157 | end 158 | 159 | return r(10) 160 | end 161 | 162 | return M 163 | -------------------------------------------------------------------------------- /plugin/dbee.lua: -------------------------------------------------------------------------------- 1 | if vim.g.loaded_dbee == 1 then 2 | return 3 | end 4 | vim.g.loaded_dbee = 1 5 | 6 | local COMMAND_NAME = "Dbee" 7 | 8 | ---@param args string args in form of Dbee arg1 arg2 ... 
9 | ---@return string[] words of args split on whitespace, without a leading command name
10 | local function split_args(args)
11 |   local ret = {}
12 |   -- split on whitespace only: "|" is an ordinary character here (e.g. the SQL "||" operator)
13 |   for word in args:gmatch("%S+") do
14 |     ret[#ret + 1] = word
15 |   end
16 |
17 |   -- a full command line starts with the command name itself: drop only that word
18 |   if ret[1] == COMMAND_NAME then
19 |     table.remove(ret, 1)
20 |   end
21 |
22 |   return ret
23 | end
24 |
25 | ---@param input integer[]
26 | ---@return string[]
27 | local function tostringlist(input)
28 |   local ret = {}
29 |   for i, elem in ipairs(input) do
30 |     ret[i] = tostring(elem)
31 |   end
32 |   return ret
33 | end
34 |
35 | -- Lazily forward a subcommand to the function of the same name in the dbee
36 | -- module; the module is only required when the subcommand runs.
37 | ---@param name string
38 | ---@return fun(args: string[])
39 | local function forward(name)
40 |   return function(args)
41 |     require("dbee")[name](args)
42 |   end
43 | end
44 |
45 | -- Subcommands of the user command. Every handler gets the words following the
46 | -- subcommand and the raw (unsplit) text following it. Defined at file level so
47 | -- that both the command and its completion function can see it.
48 | ---@type table<string, fun(args: string[], raw: string)>
49 | local commands = {
50 |   open = forward("open"),
51 |   close = forward("close"),
52 |   toggle = forward("toggle"),
53 |   execute = function(_, raw)
54 |     -- use the raw text so the query keeps its original spacing
55 |     require("dbee").execute(raw)
56 |   end,
57 |   store = function(args)
58 |     -- args are "format", "output" and "extra_arg"
59 |     if #args < 3 then
60 |       error("not enough arguments, got " .. #args .. " want 3")
61 |     end
62 |
63 |     require("dbee").store(args[1], args[2], { extra_arg = args[3] })
64 |   end,
65 | }
66 |
67 | -- Create user command for dbee
68 | vim.api.nvim_create_user_command(COMMAND_NAME, function(opts)
69 |   local args = split_args(opts.args)
70 |   if #args < 1 then
71 |     -- default is toggle
72 |     require("dbee").toggle()
73 |     return
74 |   end
75 |
76 |   local cmd = table.remove(args, 1)
77 |   local fn = commands[cmd]
78 |   if not fn then
79 |     error("unsupported subcommand: " .. cmd)
80 |   end
81 |
82 |   -- everything after the subcommand, with its inner whitespace preserved
83 |   local raw = opts.args:match("^%s*%S+%s*(.-)%s*$") or ""
84 |   fn(args, raw)
85 | end, {
86 |   nargs = "*",
87 |   complete = function(arg_lead, cmdline, _)
88 |     local words = split_args(cmdline)
89 |
90 |     -- 1-based position of the argument being completed; a non-empty arg_lead
91 |     -- means the last word is still being typed
92 |     local position = #words + 1
93 |     if arg_lead ~= "" then
94 |       position = #words
95 |     end
96 |
97 |     local candidates = {}
98 |     if position <= 1 then
99 |       candidates = vim.tbl_keys(commands)
100 |       table.sort(candidates)
101 |     elseif words[1] == "store" then
102 |       if position == 2 then
103 |         -- format
104 |         candidates = { "csv", "json", "table" }
105 |       elseif position == 3 then
106 |         -- output
107 |         candidates = { "file", "yank", "buffer" }
108 |       elseif position == 4 and words[3] == "buffer" then
109 |         -- extra_arg: buffer number
110 |         candidates = tostringlist(vim.api.nvim_list_bufs())
111 |       end
112 |     end
113 |
114 |     -- only offer candidates matching what has been typed so far
115 |     return vim.tbl_filter(function(candidate)
116 |       return vim.startswith(candidate, arg_lead)
117 |     end, candidates)
118 |   end,
119 | })
120 |
--------------------------------------------------------------------------------