├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── .golangci.yml ├── .goreleaser.yml ├── LICENSE ├── Makefile ├── README.md ├── ast ├── ast.go └── astutil │ └── astutil.go ├── dialect ├── clickhouse.go ├── dialect.go ├── h2.go ├── keyword.go ├── mssql.go ├── mysql.go ├── oracle.go ├── postgresql.go ├── sqlite.go └── vertica.go ├── doc └── develop.md ├── docker-compose.yml ├── go.mod ├── go.sum ├── imgs ├── sqls-completion.gif ├── sqls-fk_joins.gif ├── sqls_document_format.gif ├── sqls_hover.gif └── sqls_signature_help.gif ├── internal ├── completer │ ├── candidates.go │ ├── completer.go │ └── completer_test.go ├── config │ ├── config.go │ ├── config_test.go │ └── testdata │ │ ├── basic.yml │ │ ├── invalid_proto.yml │ │ ├── no_connection.yml │ │ ├── no_driver.yml │ │ ├── no_dsn.yml │ │ ├── no_host.yml │ │ ├── no_path.yml │ │ ├── no_ssh_host.yml │ │ ├── no_ssh_private_key.yml │ │ ├── no_ssh_user.yml │ │ ├── no_user.yml │ │ └── oracle.yaml ├── database │ ├── cache.go │ ├── clickhouse.go │ ├── clickhouse_test.go │ ├── config.go │ ├── database.go │ ├── database_mock.go │ ├── driver.go │ ├── h2.go │ ├── mssql.go │ ├── mssql_test.go │ ├── mysql.go │ ├── oracle.go │ ├── oracle_test.go │ ├── postgresql.go │ ├── postgresql_test.go │ ├── query_type.go │ ├── query_type_test.go │ ├── scan_row.go │ ├── sqlite3.go │ ├── vertica.go │ └── worker.go ├── debug │ └── debug.go ├── formatter │ ├── formatter.go │ ├── formatter_test.go │ ├── formatutil.go │ └── testdata │ │ ├── 001.golden │ │ ├── 001.sql │ │ ├── 002.golden │ │ ├── 002.sql │ │ ├── 003.golden │ │ ├── 003.ignore │ │ └── 003.sql ├── handler │ ├── completion.go │ ├── completion_test.go │ ├── definition.go │ ├── definition_test.go │ ├── execute_command.go │ ├── execute_command_test.go │ ├── format.go │ ├── format_test.go │ ├── handler.go │ ├── handler_test.go │ ├── hover.go │ ├── hover_test.go │ ├── rename.go │ ├── rename_test.go │ ├── 
signature_help.go │ ├── signature_help_test.go │ └── testdata │ │ ├── format │ │ ├── select_arithmetic_expression.golden.sql │ │ ├── select_arithmetic_expression.input.sql │ │ ├── select_basic.golden.sql │ │ ├── select_basic.input.sql │ │ ├── select_group_by.golden.sql │ │ ├── select_group_by.input.sql │ │ ├── select_group_by_subquery.golden.sql │ │ ├── select_group_by_subquery.input.sql │ │ ├── select_join.golden.sql │ │ ├── select_join.input.sql │ │ ├── select_statement.golden.sql │ │ ├── select_statement.input.sql │ │ ├── select_statement_with_between.golden.sql │ │ └── select_statement_with_between.input.sql │ │ ├── format_option_space2 │ │ ├── select_basic_space2.golden.sql │ │ └── select_basic_space2.input.sql │ │ ├── format_option_space4 │ │ ├── select_basic_space4.golden.sql │ │ └── select_basic_space4.input.sql │ │ └── upper_case │ │ ├── select_basic.golden.sql │ │ └── select_basic.input.sql └── lsp │ ├── client.go │ └── lsp.go ├── main.go ├── parser ├── parser.go ├── parser_test.go └── parseutil │ ├── extract.go │ ├── extract_test.go │ ├── idenfier.go │ ├── insert.go │ ├── insert_test.go │ ├── parseutil.go │ ├── parseutil_test.go │ ├── position.go │ ├── position_test.go │ └── walk.go ├── schema.json ├── script ├── help_categories.sql ├── help_functions_mysql56.sql ├── help_functions_mysql57.sql ├── help_functions_mysql8.sql ├── help_keywords_mysql56.sql ├── help_keywords_mysql57.sql └── help_keywords_mysql8.sql └── token ├── kind.go ├── kind_string.go ├── lexer.go └── lexer_test.go /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 
18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Versions (please complete the following information):** 24 | - OS: [e.g. macOS 14, Ubuntu 22.04, Windows 11] 25 | - sqls version: [e.g. v0.2.x, or the commit hash if built from source] 26 | 27 | **Additional context** 28 | Add any other context about the problem here. 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. E.g., I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here.
21 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | tags: 5 | - 'v*' 6 | env: 7 | GO_VERSION: stable 8 | 9 | jobs: 10 | build_for_linux: 11 | name: Build for Linux 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Install build dependencies 15 | run: | 16 | sudo apt-get -qq update 17 | sudo apt-get install -y --no-install-recommends \ 18 | build-essential 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | with: 22 | fetch-depth: 0 23 | - name: Setup Go 24 | uses: actions/setup-go@v4 25 | with: 26 | go-version: ${{ env.GO_VERSION }} 27 | - name: Build amd64 28 | env: 29 | CGO_ENABLED: 1 30 | GOOS: linux 31 | GOARCH: amd64 32 | run: make release 33 | - name: Archive artifacts 34 | uses: actions/upload-artifact@v4 35 | with: 36 | name: dist-linux 37 | path: sqls-linux-*.zip 38 | 39 | build_for_macos: 40 | name: Build for MacOS 41 | runs-on: macos-latest 42 | steps: 43 | - name: Install build dependencies 44 | run: brew install coreutils 45 | - name: Checkout 46 | uses: actions/checkout@v3 47 | with: 48 | fetch-depth: 0 49 | - name: Setup Go 50 | uses: actions/setup-go@v4 51 | with: 52 | go-version: ${{ env.GO_VERSION }} 53 | - name: Build amd64 54 | env: 55 | CGO_ENABLED: 1 56 | GOOS: darwin 57 | GOARCH: amd64 58 | run: make release 59 | - name: Archive artifacts 60 | uses: actions/upload-artifact@v4 61 | with: 62 | name: dist-darwin 63 | path: sqls-darwin-*.zip 64 | 65 | build_for_windows: 66 | name: Build for Windows 67 | runs-on: windows-latest 68 | steps: 69 | - name: Install build dependencies 70 | run: choco install zip 71 | - name: Checkout 72 | uses: actions/checkout@v3 73 | with: 74 | fetch-depth: 0 75 | - name: Setup Go 76 | uses: actions/setup-go@v4 77 | with: 78 | go-version: ${{ env.GO_VERSION }} 79 | - name: Build amd64 80 | shell: bash 81 | env: 82 |
CGO_ENABLED: 1 83 | GOOS: windows 84 | GOARCH: amd64 85 | run: make release 86 | - name: Archive artifacts 87 | uses: actions/upload-artifact@v4 88 | with: 89 | name: dist-windows 90 | path: sqls-windows-*.zip 91 | 92 | release: 93 | name: Draft Release 94 | needs: 95 | - build_for_linux 96 | - build_for_macos 97 | - build_for_windows 98 | runs-on: ubuntu-latest 99 | steps: 100 | - name: Download artifacts 101 | uses: actions/download-artifact@v4 102 | - name: Release 103 | uses: softprops/action-gh-release@v1 104 | if: startsWith(github.ref, 'refs/tags/') 105 | with: 106 | name: sqls ${{ github.ref_name }} 107 | token: ${{ secrets.GITHUB_TOKEN }} 108 | draft: true 109 | generate_release_notes: true 110 | files: dist-*/sqls*.* 111 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | pull_request: 5 | push: 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Setup Go 13 | uses: actions/setup-go@v3 14 | with: 15 | go-version: 1.21 16 | - name: Lint 17 | uses: golangci/golangci-lint-action@v3 18 | with: 19 | version: v1.54 20 | - name: Test 21 | run: go test -coverprofile coverage.out -covermode atomic ./...
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/go 2 | # Edit at https://www.gitignore.io/?templates=go 3 | 4 | ### Go ### 5 | # Binaries for programs and plugins 6 | *.exe 7 | *.exe~ 8 | *.dll 9 | *.so 10 | *.dylib 11 | 12 | # Test binary, built with `go test -c` 13 | *.test 14 | 15 | # Output of the go coverage tool, specifically when used with LiteIDE 16 | *.out 17 | 18 | # Dependency directories (remove the comment below to include it) 19 | # vendor/ 20 | 21 | ### Go Patch ### 22 | /vendor/ 23 | /Godeps/ 24 | 25 | # End of https://www.gitignore.io/api/go 26 | 27 | # docker 28 | docker 29 | .docker 30 | 31 | # build dist 32 | sqls 33 | dist/ 34 | 35 | # direnv 36 | .envrc 37 | 38 | # vim 39 | tags 40 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | tests: false 3 | timeout: 5m 4 | linters: 5 | disable: 6 | - maligned 7 | - prealloc 8 | - rowserrcheck 9 | disable-all: false 10 | presets: 11 | - bugs 12 | fast: false 13 | issues: 14 | exclude-rules: 15 | - linters: 16 | - gosec 17 | text: "G106:" 18 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | # This is an example goreleaser.yaml file with some sane defaults. 2 | # Make sure to check the documentation at http://goreleaser.com 3 | before: 4 | hooks: 5 | # You may remove this if you don't use go modules. 
6 | - go mod download 7 | builds: 8 | - id: sqls-darwin-amd64 9 | ldflags: 10 | - -s -w 11 | - -X main.version={{.Version}} 12 | - -X main.revision={{.ShortCommit}} 13 | env: 14 | - CGO_ENABLED=1 15 | - CC=o64-clang 16 | - CXX=o64-clang++ 17 | goos: 18 | - darwin 19 | goarch: 20 | - amd64 21 | - id: sqls-linux-amd64 22 | ldflags: 23 | - -s -w 24 | - -X main.version={{.Version}} 25 | - -X main.revision={{.ShortCommit}} 26 | env: 27 | - CGO_ENABLED=1 28 | goos: 29 | - linux 30 | goarch: 31 | - amd64 32 | - id: sqls-windows-amd64 33 | ldflags: 34 | - -s -w 35 | - -X main.version={{.Version}} 36 | - -X main.revision={{.ShortCommit}} 37 | env: 38 | - CGO_ENABLED=1 39 | - CC=x86_64-w64-mingw32-gcc 40 | - CXX=x86_64-w64-mingw32-g++ 41 | goos: 42 | - windows 43 | goarch: 44 | - amd64 45 | archives: 46 | - replacements: 47 | darwin: Darwin 48 | linux: Linux 49 | windows: Windows 50 | amd64: x86_64 51 | checksum: 52 | name_template: 'checksums.txt' 53 | snapshot: 54 | name_template: "{{ .Tag }}" 55 | changelog: 56 | sort: asc 57 | filters: 58 | exclude: 59 | - '^docs:' 60 | - '^test:' 61 | release: 62 | github: 63 | owner: sqls-server 64 | name: sqls 65 | # If set to auto, will mark the release as not ready for production 66 | # in case there is an indicator for this in the tag e.g. v1.0.0-rc1 67 | # If set to true, will mark the release as not ready for production. 68 | # Default is false. 69 | prerelease: auto 70 | # You can change the name of the release. 
71 | # Default is `{{.Tag}}` 72 | name_template: "{{.ProjectName}} v{{.Version}}" 73 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Toshikazu Ohashi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | BIN := sqls 2 | ifeq ($(OS),Windows_NT) 3 | BIN := $(BIN).exe 4 | endif 5 | VERSION := $$(make -s show-version) 6 | CURRENT_REVISION := $(shell git rev-parse --short HEAD) 7 | BUILD_LDFLAGS := "-s -w -X main.revision=$(CURRENT_REVISION)" 8 | GOOS := $(shell go env GOOS) 9 | GOBIN ?= $(shell go env GOPATH)/bin 10 | export GO111MODULE=on 11 | 12 | .PHONY: all 13 | all: clean build 14 | 15 | .PHONY: build 16 | build: 17 | go build -ldflags=$(BUILD_LDFLAGS) -o $(BIN) . 18 | 19 | .PHONY: release 20 | release: 21 | go build -ldflags=$(BUILD_LDFLAGS) -o $(BIN) . 22 | zip -r sqls-$(GOOS)-$(VERSION).zip $(BIN) 23 | 24 | .PHONY: install 25 | install: 26 | go install -ldflags=$(BUILD_LDFLAGS) . 27 | 28 | .PHONY: show-version 29 | show-version: $(GOBIN)/gobump 30 | gobump show -r . 31 | 32 | $(GOBIN)/gobump: 33 | go install github.com/x-motemen/gobump/cmd/gobump@latest 34 | 35 | .PHONY: test 36 | test: build 37 | go test -v ./... 38 | 39 | .PHONY: clean 40 | clean: 41 | go clean 42 | 43 | .PHONY: bump 44 | bump: $(GOBIN)/gobump 45 | ifneq ($(shell git status --porcelain),) 46 | $(error git workspace is dirty) 47 | endif 48 | ifneq ($(shell git rev-parse --abbrev-ref HEAD),master) 49 | $(error current branch is not master) 50 | endif 51 | @gobump up -w . 
52 | git commit -am "bump up version to $(VERSION)" 53 | git tag "v$(VERSION)" 54 | git push origin master 55 | git push origin "refs/tags/v$(VERSION)" 56 | -------------------------------------------------------------------------------- /ast/astutil/astutil.go: -------------------------------------------------------------------------------- 1 | package astutil 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/sqls-server/sqls/ast" 8 | "github.com/sqls-server/sqls/dialect" 9 | "github.com/sqls-server/sqls/token" 10 | ) 11 | 12 | // NodeMatcher describes the criteria used to select nodes. A node matches 13 | // when it satisfies at least one of the non-empty criteria. 14 | type NodeMatcher struct { 15 | NodeTypes []ast.NodeType 16 | ExpectTokens []token.Kind 17 | ExpectSQLType []dialect.KeywordKind 18 | ExpectKeyword []string 19 | } 20 | 21 | // IsMatchNodeTypes reports whether node's type is one of nm.NodeTypes. 22 | func (nm *NodeMatcher) IsMatchNodeTypes(node ast.Node) bool { 23 | if nm.NodeTypes != nil { 24 | for _, expect := range nm.NodeTypes { 25 | if expect == node.Type() { 26 | return true 27 | } 28 | } 29 | } 30 | return false 31 | } 32 | 33 | // IsMatchTokens reports whether tok has one of the kinds in nm.ExpectTokens. 34 | func (nm *NodeMatcher) IsMatchTokens(tok *ast.SQLToken) bool { 35 | if nm.ExpectTokens != nil { 36 | for _, expect := range nm.ExpectTokens { 37 | if tok.MatchKind(expect) { 38 | return true 39 | } 40 | } 41 | } 42 | return false 43 | } 44 | 45 | // IsMatchSQLType reports whether tok matches one of the keyword kinds in 46 | // nm.ExpectSQLType. 47 | func (nm *NodeMatcher) IsMatchSQLType(tok *ast.SQLToken) bool { 48 | if nm.ExpectSQLType != nil { 49 | for _, expect := range nm.ExpectSQLType { 50 | if tok.MatchSQLKind(expect) { 51 | return true 52 | } 53 | } 54 | } 55 | return false 56 | } 57 | 58 | // IsMatchKeyword reports whether node's text equals one of nm.ExpectKeyword, 59 | // ignoring case. 60 | func (nm *NodeMatcher) IsMatchKeyword(node ast.Node) bool { 61 | if nm.ExpectKeyword != nil { 62 | for _, expect := range nm.ExpectKeyword { 63 | if strings.EqualFold(expect, node.String()) { 64 | return true 65 | } 66 | } 67 | } 68 | return false 69 | } 70 |
71 | // IsMatch reports whether node satisfies any of nm's criteria. Node types and 72 | // keywords are checked for every node; token kinds and SQL keyword kinds are 73 | // checked only for leaf tokens. It panics if node is neither an ast.TokenList 74 | // nor an ast.Token. 75 | func (nm *NodeMatcher) IsMatch(node ast.Node) bool { 76 | // For node object 77 | if nm.IsMatchNodeTypes(node) { 78 | return true 79 | } 80 | if nm.IsMatchKeyword(node) { 81 | return true 82 | } 83 | if _, ok := node.(ast.TokenList); ok { 84 | return false 85 | } 86 | // For token object 87 | tok, ok := node.(ast.Token) 88 | if !ok { 89 | panic(fmt.Sprintf("invalid type. not has Token, got=(type: %T, value: %#v)", node, node.String())) 90 | } 91 | sqlTok := tok.GetToken() 92 | if nm.IsMatchTokens(sqlTok) || nm.IsMatchSQLType(sqlTok) { 93 | return true 94 | } 95 | return false 96 | } 97 | 98 | // isWhitespace reports whether node is a leaf token of kind token.Whitespace. 99 | func isWhitespace(node ast.Node) bool { 100 | tok, ok := node.(ast.Token) 101 | if !ok { 102 | return false 103 | } 104 | if tok.GetToken().MatchKind(token.Whitespace) { 105 | return true 106 | } 107 | return false 108 | } 109 | 110 | // matchesLeafToken reports whether node is a leaf token (not a token list) 111 | // whose token kind, SQL keyword kind or text satisfies nm. 112 | func matchesLeafToken(nm *NodeMatcher, node ast.Node) bool { 113 | if _, ok := node.(ast.TokenList); ok { 114 | return false 115 | } 116 | tok, ok := node.(ast.Token) 117 | if !ok { 118 | return false 119 | } 120 | sqlTok := tok.GetToken() 121 | return nm.IsMatchTokens(sqlTok) || nm.IsMatchSQLType(sqlTok) || nm.IsMatchKeyword(sqlTok) 122 | } 123 |
124 | // NodeReader is a cursor over the direct children of a TokenList. Index is 125 | // the position of the next child NextNode will read, and CurNode is the child 126 | // most recently read. 127 | type NodeReader struct { 128 | Node ast.TokenList 129 | CurNode ast.Node 130 | Index int 131 | } 132 | 133 | // NewNodeReader returns a reader positioned before the first child of list. 134 | func NewNodeReader(list ast.TokenList) *NodeReader { 135 | return &NodeReader{ 136 | Node: list, 137 | } 138 | } 139 | 140 | // CopyReader returns a new reader over the same list at the same Index. 141 | // CurNode is not copied, so the copy starts with a nil CurNode. 142 | func (nr *NodeReader) CopyReader() *NodeReader { 143 | return &NodeReader{ 144 | Node: nr.Node, 145 | Index: nr.Index, 146 | } 147 | } 148 | 149 | // Replace overwrites the child at index with add and stores the children 150 | // back into the list. It panics if index is out of range. 151 | func (nr *NodeReader) Replace(add ast.Node, index int) { 152 | list := nr.Node.GetTokens() 153 | list[index] = add 154 | nr.Node.SetTokens(list) 155 | } 156 | 157 | // NodesWithRange returns the children in the half-open range 158 | // [startIndex, endIndex). 159 | func (nr *NodeReader) NodesWithRange(startIndex, endIndex int) []ast.Node { 160 | return nr.Node.GetTokens()[startIndex:endIndex] 161 | } 162 | 163 | // hasNext reports whether a child remains at or after Index. 164 | func (nr *NodeReader) hasNext() bool { 165 | return nr.Index < len(nr.Node.GetTokens()) 166 | } 167 | 168 | // hasPrev reports whether Index is past the first child. 169 | func (nr *NodeReader) hasPrev() bool { 170 | return 0 < nr.Index 171 | } 172 |
173 | // NextNode makes the child at Index the CurNode and advances Index past it. 174 | // When ignoreWhiteSpace is true, whitespace tokens are stepped over. It 175 | // returns false when there is no further child to read. 176 | func (nr *NodeReader) NextNode(ignoreWhiteSpace bool) bool { 177 | if !nr.hasNext() { 178 | return false 179 | } 180 | nr.CurNode = nr.Node.GetTokens()[nr.Index] 181 | nr.Index++ 182 | 183 | if ignoreWhiteSpace && isWhitespace(nr.CurNode) { 184 | return nr.NextNode(ignoreWhiteSpace) 185 | } 186 | return true 187 | } 188 | 189 | // prev moves Index back by one and makes that child the CurNode, stepping 190 | // over whitespace when ignoreWhiteSpace is true. It returns false when Index 191 | // is already 0. 192 | func (nr *NodeReader) prev(ignoreWhiteSpace bool) bool { 193 | if !nr.hasPrev() { 194 | return false 195 | } 196 | nr.Index-- 197 | nr.CurNode = nr.Node.GetTokens()[nr.Index] 198 | 199 | if ignoreWhiteSpace && isWhitespace(nr.CurNode) { 200 | return nr.prev(ignoreWhiteSpace) 201 | } 202 | return true 203 | } 204 | 205 | // CurNodeIs reports whether CurNode is set and matches nm. 206 | func (nr *NodeReader) CurNodeIs(nm NodeMatcher) bool { 207 | if nr.CurNode != nil { 208 | if nm.IsMatch(nr.CurNode) { 209 | return true 210 | } 211 | } 212 | return false 213 | } 214 | 215 | // IsEnclose reports whether pos falls within node's span, from node.Pos() 216 | // through node.End() inclusive, as ordered by token.ComparePos. 217 | func IsEnclose(node ast.Node, pos token.Pos) bool { 218 | if 0 <= token.ComparePos(pos, node.Pos()) && 0 >= token.ComparePos(pos, node.End()) { 219 | return true 220 | } 221 | return false 222 | } 223 | 224 | // CurNodeEncloseIs reports whether CurNode is set and encloses pos. 225 | func (nr *NodeReader) CurNodeEncloseIs(pos token.Pos) bool { 226 | if nr.CurNode != nil { 227 | return IsEnclose(nr.CurNode, pos) 228 | } 229 | return false 230 | } 231 | 232 | // PeekNodeEncloseIs reports whether the child at Index (whitespace included) 233 | // encloses pos. 234 | func (nr *NodeReader) PeekNodeEncloseIs(pos token.Pos) bool { 235 | _, peekNode := nr.PeekNode(false) 236 | if peekNode != nil { 237 | return IsEnclose(peekNode, pos) 238 | } 239 | return false 240 | } 241 |
242 | // PeekNode returns the next child at or after Index, and its index, without 243 | // moving the reader. When ignoreWhiteSpace is true, whitespace tokens are 244 | // skipped. It returns 0, nil when there is no such child. 245 | func (nr *NodeReader) PeekNode(ignoreWhiteSpace bool) (int, ast.Node) { 246 | tmpReader := nr.CopyReader() 247 | for tmpReader.hasNext() { 248 | index := tmpReader.Index 249 | node := tmpReader.Node.GetTokens()[index] 250 | 251 | if ignoreWhiteSpace { 252 | if !isWhitespace(node) { 253 | return index, node 254 | } 255 | } else { 256 | return index, node 257 | } 258 | tmpReader.NextNode(false) 259 | } 260 | return 0, nil 261 | } 262 | 263 | // PeekNodeIs reports whether the child returned by PeekNode matches nm. 264 | func (nr *NodeReader) PeekNodeIs(ignoreWhiteSpace bool, nm NodeMatcher) bool { 265 | _, node := nr.PeekNode(ignoreWhiteSpace) 266 | if node != nil { 267 | if nm.IsMatch(node) { 268 | return true 269 | } 270 | } 271 | return false 272 | } 273 | 274 | // TailNode returns the last child of the list together with the index just 275 | // past it (the number of children). Because CopyReader does not carry 276 | // CurNode over, the returned node is nil when nr has no unread children. 277 | func (nr *NodeReader) TailNode() (int, ast.Node) { 278 | var ( 279 | index int 280 | node ast.Node 281 | ) 282 | 283 | tmpReader := nr.CopyReader() 284 | for { 285 | index = tmpReader.Index 286 | node = tmpReader.CurNode 287 | if !tmpReader.hasNext() { 288 | break 289 | } 290 | tmpReader.NextNode(false) 291 | } 292 | return index, node 293 | } 294 |
295 | // FindNode searches forward from Index, without moving nr, for the first 296 | // child that matches nm: node types are checked for every child, while token 297 | // kinds, SQL keyword kinds and keywords are checked only for leaf tokens. 298 | // When ignoreWhiteSpace is true, whitespace tokens are never reported as a 299 | // match. On success it returns a reader whose Index is the position of the 300 | // matching child, along with that child; otherwise it returns nil, nil. 301 | func (nr *NodeReader) FindNode(ignoreWhiteSpace bool, nm NodeMatcher) (*NodeReader, ast.Node) { 302 | tmpReader := nr.CopyReader() 303 | for tmpReader.hasNext() { 304 | node := tmpReader.Node.GetTokens()[tmpReader.Index] 305 | 306 | skip := ignoreWhiteSpace && isWhitespace(node) 307 | if !skip && (nm.IsMatchNodeTypes(node) || matchesLeafToken(&nm, node)) { 308 | return tmpReader, node 309 | } 310 | // Advance exactly one child per iteration so that every child is 311 | // inspected; whitespace is filtered above instead of by NextNode, 312 | // which would consume the child after a whitespace token unseen. 313 | tmpReader.NextNode(false) 314 | } 315 | return nil, nil 316 | } 317 | 318 | // FindRecursive reads nr to the end and returns every remaining child, and 319 | // every descendant of those children, that matches matcher, in pre-order. 320 | func (nr *NodeReader) FindRecursive(matcher NodeMatcher) []ast.Node { 321 | matches := []ast.Node{} 322 | for nr.NextNode(false) { 323 | if nr.CurNodeIs(matcher) { 324 | matches = append(matches, nr.CurNode) 325 | } 326 | if list, ok := nr.CurNode.(ast.TokenList); ok { 327 | newReader := NewNodeReader(list) 328 | matches = append(matches, newReader.FindRecursive(matcher)...) 329 | } 330 | } 331 | return matches 332 | } 333 |
334 | // PrevNode returns the child before the one at Index-1 (after NextNode, the 335 | // node preceding CurNode), stepping over whitespace when ignoreWhiteSpace is 336 | // true, without moving nr. It returns 0, nil when there is no such child. 337 | func (nr *NodeReader) PrevNode(ignoreWhiteSpace bool) (int, ast.Node) { 338 | if !nr.hasPrev() { 339 | return 0, nil 340 | } 341 | tmpReader := nr.CopyReader() 342 | tmpReader.prev(false) 343 | 344 | for tmpReader.prev(ignoreWhiteSpace) { 345 | index := tmpReader.Index 346 | node := tmpReader.CurNode 347 | 348 | if ignoreWhiteSpace { 349 | if !isWhitespace(node) { 350 | return index, node 351 | } 352 | } else { 353 | return index, node 354 | } 355 | 356 | if !tmpReader.hasPrev() { 357 | break 358 | } 359 | } 360 | return 0, nil 361 | } 362 | 363 | // PrevNodeIs reports whether the child returned by PrevNode matches nm. 364 | func (nr *NodeReader) PrevNodeIs(ignoreWhiteSpace bool, nm NodeMatcher) bool { 365 | _, node := nr.PrevNode(ignoreWhiteSpace) 366 | if node != nil { 367 | if nm.IsMatch(node) { 368 | return true 369 | } 370 | } 371 | return false 372 | } 373 | -------------------------------------------------------------------------------- /dialect/clickhouse.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | // TODO(patrick.pichler): figure out real keywords as those are copied over from postgres 4 | var clickhouseKeywords = []string{ 5 | "ABORT", 6 | "ABSOLUTE", 7 | "ACCESS", 8 | "ACTION", 9 | "ADD", 10 | "ADMIN", 11 | "AFTER", 12 | "AGGREGATE", 13 | "ALL", 14 | "ALSO", 15 | "ALTER", 16 | "ALWAYS", 17 | "ANALYSE", 18 | "ANALYZE", 19 | "AND", 20 | "ANY", 21 | "ARRAY", 22 | "AS", 23 | "ASC", 24 | "ASSERTION", 25 | "ASSIGNMENT", 26 | "ASYMMETRIC", 27 | "AT", 28 | "ATTACH", 29 | "ATTRIBUTE", 30 | "AUTHORIZATION", 31 | "BACKWARD", 32 | "BEFORE", 33 | "BEGIN", 34 | "BETWEEN", 35 | "BIGINT", 36 | "BINARY", 37 | "BIT", 38 | "BOOLEAN", 39 | "BOTH", 40 | "BY", 41 | "CACHE", 42 | "CALL", 43 | "CALLED", 44 | "CASCADE", 45 | "CASCADED", 46 | "CASE", 47 | "CAST", 48 | "CATALOG", 49 | "CHAIN", 50 | "CHAR", 51 | "CHARACTER", 52 | "CHARACTERISTICS", 53 | "CHECK", 54 | "CHECKPOINT", 55 | "CLASS", 56 |
"CLOSE", 57 | "CLUSTER", 58 | "COALESCE", 59 | "COLLATE", 60 | "COLLATION", 61 | "COLUMN", 62 | "COLUMNS", 63 | "COMMENT", 64 | "COMMENTS", 65 | "COMMIT", 66 | "COMMITTED", 67 | "CONCURRENTLY", 68 | "CONFIGURATION", 69 | "CONFLICT", 70 | "CONNECTION", 71 | "CONSTRAINT", 72 | "CONSTRAINTS", 73 | "CONTENT", 74 | "CONTINUE", 75 | "CONVERSION", 76 | "COPY", 77 | "COST", 78 | "CREATE", 79 | "CROSS", 80 | "CSV", 81 | "CUBE", 82 | "CURRENT", 83 | "CURRENT_CATALOG", 84 | "CURRENT_DATE", 85 | "CURRENT_ROLE", 86 | "CURRENT_SCHEMA", 87 | "CURRENT_TIME", 88 | "CURRENT_TIMESTAMP", 89 | "CURRENT_USER", 90 | "CURSOR", 91 | "CYCLE", 92 | "DATA", 93 | "DATABASE", 94 | "DAY", 95 | "DEALLOCATE", 96 | "DEC", 97 | "DECIMAL", 98 | "DECLARE", 99 | "DEFAULT", 100 | "DEFAULTS", 101 | "DEFERRABLE", 102 | "DEFERRED", 103 | "DEFINER", 104 | "DELETE", 105 | "DELIMITER", 106 | "DELIMITERS", 107 | "DEPENDS", 108 | "DESC", 109 | "DETACH", 110 | "DICTIONARY", 111 | "DISABLE", 112 | "DISCARD", 113 | "DISTINCT", 114 | "DO", 115 | "DOCUMENT", 116 | "DOMAIN", 117 | "DOUBLE", 118 | "DROP", 119 | "EACH", 120 | "ELSE", 121 | "ENABLE", 122 | "ENCODING", 123 | "ENCRYPTED", 124 | "END", 125 | "ENUM", 126 | "ESCAPE", 127 | "EVENT", 128 | "EXCEPT", 129 | "EXCLUDE", 130 | "EXCLUDING", 131 | "EXCLUSIVE", 132 | "EXECUTE", 133 | "EXISTS", 134 | "EXPLAIN", 135 | "EXPRESSION", 136 | "EXTENSION", 137 | "EXTERNAL", 138 | "EXTRACT", 139 | "FALSE", 140 | "FAMILY", 141 | "FETCH", 142 | "FILTER", 143 | "FIRST", 144 | "FLOAT", 145 | "FOLLOWING", 146 | "FOR", 147 | "FORCE", 148 | "FOREIGN", 149 | "FORWARD", 150 | "FREEZE", 151 | "FROM", 152 | "FULL", 153 | "FUNCTION", 154 | "FUNCTIONS", 155 | "GENERATED", 156 | "GLOBAL", 157 | "GRANT", 158 | "GRANTED", 159 | "GREATEST", 160 | "GROUP", 161 | "GROUPING", 162 | "GROUPS", 163 | "HANDLER", 164 | "HAVING", 165 | "HEADER", 166 | "HOLD", 167 | "HOUR", 168 | "IDENTITY", 169 | "IF", 170 | "ILIKE", 171 | "IMMEDIATE", 172 | "IMMUTABLE", 173 | "IMPLICIT", 174 | "IMPORT", 175 | "IN", 
176 | "INCLUDE", 177 | "INCLUDING", 178 | "INCREMENT", 179 | "INDEX", 180 | "INDEXES", 181 | "INHERIT", 182 | "INHERITS", 183 | "INITIALLY", 184 | "INLINE", 185 | "INNER", 186 | "INOUT", 187 | "INPUT", 188 | "INSENSITIVE", 189 | "INSERT", 190 | "INSTEAD", 191 | "INT", 192 | "INTEGER", 193 | "INTERSECT", 194 | "INTERVAL", 195 | "INTO", 196 | "INVOKER", 197 | "IS", 198 | "ISNULL", 199 | "ISOLATION", 200 | "JOIN", 201 | "KEY", 202 | "LABEL", 203 | "LANGUAGE", 204 | "LARGE", 205 | "LAST", 206 | "LATERAL", 207 | "LEADING", 208 | "LEAKPROOF", 209 | "LEAST", 210 | "LEFT", 211 | "LEVEL", 212 | "LIKE", 213 | "LIMIT", 214 | "LISTEN", 215 | "LOAD", 216 | "LOCAL", 217 | "LOCALTIME", 218 | "LOCALTIMESTAMP", 219 | "LOCATION", 220 | "LOCK", 221 | "LOCKED", 222 | "LOGGED", 223 | "MAPPING", 224 | "MATCH", 225 | "MATERIALIZED", 226 | "MAXVALUE", 227 | "METHOD", 228 | "MINUTE", 229 | "MINVALUE", 230 | "MODE", 231 | "MONTH", 232 | "MOVE", 233 | "NAME", 234 | "NAMES", 235 | "NATIONAL", 236 | "NATURAL", 237 | "NCHAR", 238 | "NEW", 239 | "NEXT", 240 | "NFC", 241 | "NFD", 242 | "NFKC", 243 | "NFKD", 244 | "NO", 245 | "NONE", 246 | "NORMALIZE", 247 | "NORMALIZED", 248 | "NOT", 249 | "NOTHING", 250 | "NOTIFY", 251 | "NOTNULL", 252 | "NOWAIT", 253 | "NULL", 254 | "NULLIF", 255 | "NULLS", 256 | "NUMERIC", 257 | "OBJECT", 258 | "OF", 259 | "OFF", 260 | "OFFSET", 261 | "OIDS", 262 | "OLD", 263 | "ON", 264 | "ONLY", 265 | "OPERATOR", 266 | "OPTION", 267 | "OPTIONS", 268 | "OR", 269 | "ORDER", 270 | "ORDINALITY", 271 | "OTHERS", 272 | "OUT", 273 | "OUTER", 274 | "OVER", 275 | "OVERLAPS", 276 | "OVERLAY", 277 | "OVERRIDING", 278 | "OWNED", 279 | "OWNER", 280 | "PARALLEL", 281 | "PARSER", 282 | "PARTIAL", 283 | "PARTITION", 284 | "PASSING", 285 | "PASSWORD", 286 | "PLACING", 287 | "PLANS", 288 | "POLICY", 289 | "POSITION", 290 | "PRECEDING", 291 | "PRECISION", 292 | "PREPARE", 293 | "PREPARED", 294 | "PRESERVE", 295 | "PRIMARY", 296 | "PRIOR", 297 | "PRIVILEGES", 298 | "PROCEDURAL", 299 | 
"PROCEDURE", 300 | "PROCEDURES", 301 | "PROGRAM", 302 | "PUBLICATION", 303 | "QUOTE", 304 | "RANGE", 305 | "READ", 306 | "REAL", 307 | "REASSIGN", 308 | "RECHECK", 309 | "RECURSIVE", 310 | "REF", 311 | "REFERENCES", 312 | "REFERENCING", 313 | "REFRESH", 314 | "REINDEX", 315 | "RELATIVE", 316 | "RELEASE", 317 | "RENAME", 318 | "REPEATABLE", 319 | "REPLACE", 320 | "REPLICA", 321 | "RESET", 322 | "RESTART", 323 | "RESTRICT", 324 | "RETURNING", 325 | "RETURNS", 326 | "REVOKE", 327 | "RIGHT", 328 | "ROLE", 329 | "ROLLBACK", 330 | "ROLLUP", 331 | "ROUTINE", 332 | "ROUTINES", 333 | "ROW", 334 | "ROWS", 335 | "RULE", 336 | "SAVEPOINT", 337 | "SCHEMA", 338 | "SCHEMAS", 339 | "SCROLL", 340 | "SEARCH", 341 | "SECOND", 342 | "SECURITY", 343 | "SELECT", 344 | "SEQUENCE", 345 | "SEQUENCES", 346 | "SERIALIZABLE", 347 | "SERVER", 348 | "SESSION", 349 | "SESSION_USER", 350 | "SET", 351 | "SETOF", 352 | "SETS", 353 | "SHARE", 354 | "SHOW", 355 | "SIMILAR", 356 | "SIMPLE", 357 | "SKIP", 358 | "SMALLINT", 359 | "SNAPSHOT", 360 | "SOME", 361 | "SQL", 362 | "STABLE", 363 | "STANDALONE", 364 | "START", 365 | "STATEMENT", 366 | "STATISTICS", 367 | "STDIN", 368 | "STDOUT", 369 | "STORAGE", 370 | "STORED", 371 | "STRICT", 372 | "STRIP", 373 | "SUBSCRIPTION", 374 | "SUBSTRING", 375 | "SUPPORT", 376 | "SYMMETRIC", 377 | "SYSID", 378 | "SYSTEM", 379 | "TABLE", 380 | "TABLES", 381 | "TABLESAMPLE", 382 | "TABLESPACE", 383 | "TEMP", 384 | "TEMPLATE", 385 | "TEMPORARY", 386 | "TEXT", 387 | "THEN", 388 | "TIES", 389 | "TIME", 390 | "TIMESTAMP", 391 | "TO", 392 | "TRAILING", 393 | "TRANSACTION", 394 | "TRANSFORM", 395 | "TREAT", 396 | "TRIGGER", 397 | "TRIM", 398 | "TRUE", 399 | "TRUNCATE", 400 | "TRUSTED", 401 | "TYPE", 402 | "TYPES", 403 | "UESCAPE", 404 | "UNBOUNDED", 405 | "UNCOMMITTED", 406 | "UNENCRYPTED", 407 | "UNION", 408 | "UNIQUE", 409 | "UNKNOWN", 410 | "UNLISTEN", 411 | "UNLOGGED", 412 | "UNTIL", 413 | "UPDATE", 414 | "USER", 415 | "USING", 416 | "VACUUM", 417 | "VALID", 418 | 
"VALIDATE", 419 | "VALIDATOR", 420 | "VALUE", 421 | "VALUES", 422 | "VARCHAR", 423 | "VARIADIC", 424 | "VARYING", 425 | "VERBOSE", 426 | "VERSION", 427 | "VIEW", 428 | "VIEWS", 429 | "VOLATILE", 430 | "WHEN", 431 | "WHERE", 432 | "WHITESPACE", 433 | "WINDOW", 434 | "WITH", 435 | "WITHIN", 436 | "WITHOUT", 437 | "WORK", 438 | "WRAPPER", 439 | "WRITE", 440 | "XML", 441 | "XMLATTRIBUTES", 442 | "XMLCONCAT", 443 | "XMLELEMENT", 444 | "XMLEXISTS", 445 | "XMLFOREST", 446 | "XMLNAMESPACES", 447 | "XMLPARSE", 448 | "XMLPI", 449 | "XMLROOT", 450 | "XMLSERIALIZE", 451 | "XMLTABLE", 452 | "YEAR", 453 | "YES", 454 | "ZONE", 455 | } 456 | -------------------------------------------------------------------------------- /dialect/dialect.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | type Dialect interface { 4 | IsIdentifierStart(r rune) bool 5 | IsIdentifierPart(r rune) bool 6 | IsDelimitedIdentifierStart(r rune) bool 7 | IsPlaceHolderStart(r rune) bool 8 | IsPlaceHolderPart(r rune) bool 9 | } 10 | 11 | type GenericSQLDialect struct { 12 | } 13 | 14 | func (*GenericSQLDialect) IsIdentifierStart(r rune) bool { 15 | return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || r == '@' 16 | } 17 | 18 | func (*GenericSQLDialect) IsIdentifierPart(r rune) bool { 19 | return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '@' || r == '_' 20 | } 21 | 22 | func (*GenericSQLDialect) IsDelimitedIdentifierStart(r rune) bool { 23 | return r == '"' || r == '`' 24 | } 25 | 26 | func (*GenericSQLDialect) IsPlaceHolderStart(r rune) bool { 27 | return r == '$' 28 | } 29 | 30 | func (*GenericSQLDialect) IsPlaceHolderPart(r rune) bool { 31 | return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') 32 | } 33 | 34 | var _ Dialect = &GenericSQLDialect{} 35 | -------------------------------------------------------------------------------- /dialect/h2.go: 
-------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | // h2Keywords lists the reserved words of the H2 database: https://h2database.com/html/advanced.html#keywords 4 | var h2Keywords = []string{ 5 | "ALL", 6 | "AND", 7 | "ANY", 8 | "ARRAY", 9 | "AS", 10 | "ASYMMETRIC", 11 | "AUTHORIZATION", 12 | "BETWEEN", 13 | "BOTH", 14 | "CASE", 15 | "CAST", 16 | "CHECK", 17 | "CONSTRAINT", 18 | "CROSS", 19 | "CURRENT_CATALOG", 20 | "CURRENT_DATE", 21 | "CURRENT_PATH", 22 | "CURRENT_ROLE", 23 | "CURRENT_SCHEMA", 24 | "CURRENT_TIME", 25 | "CURRENT_TIMESTAMP", 26 | "CURRENT_USER", 27 | "DAY", 28 | "DEFAULT", 29 | "DISTINCT", 30 | "ELSE", 31 | "END", 32 | "EXCEPT", 33 | "EXISTS", 34 | "FALSE", 35 | "FETCH", 36 | "FILTER", 37 | "FOR", 38 | "FOREIGN", 39 | "FROM", 40 | "FULL", 41 | "GROUP", 42 | "GROUPS", 43 | "HAVING", 44 | "HOUR", 45 | "IF", 46 | "ILIKE", 47 | "IN", 48 | "INNER", 49 | "INTERSECT", 50 | "INTERVAL", 51 | "IS", 52 | "JOIN", 53 | "KEY", 54 | "LEADING", 55 | "LEFT", 56 | "LIKE", 57 | "LIMIT", 58 | "LOCALTIME", 59 | "LOCALTIMESTAMP", 60 | "MINUS", 61 | "MINUTE", 62 | "MONTH", 63 | "NATURAL", 64 | "NOT", 65 | "NULL", 66 | "OFFSET", 67 | "ON", 68 | "OR", 69 | "ORDER", 70 | "OVER", 71 | "PARTITION", 72 | "PRIMARY", 73 | "QUALIFY", 74 | "RANGE", 75 | "REGEXP", 76 | "RIGHT", 77 | "ROW", 78 | "ROWNUM", 79 | "ROWS", 80 | "SECOND", 81 | "SELECT", 82 | "SESSION_USER", 83 | "SET", 84 | "SOME", 85 | "SYMMETRIC", 86 | "SYSTEM_USER", 87 | "TABLE", 88 | "TO", 89 | "TOP", 90 | "TRAILING", 91 | "TRUE", 92 | "UESCAPE", 93 | "UNION", 94 | "UNIQUE", 95 | "UNKNOWN", 96 | "USER", 97 | "USING", 98 | "VALUE", 99 | "VALUES", 100 | "WHEN", 101 | "WHERE", 102 | "WINDOW", 103 | "WITH", 104 | "YEAR", 105 | "_ROWID_", 106 | } 107 | -------------------------------------------------------------------------------- /dialect/mssql.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | // https://docs.microsoft.com/en-us/sql/t-sql/language-elements/reserved-keywords-transact-sql 4 | var
mssqlKeywords = []string{ 5 | "ADD", 6 | "ALL", 7 | "ALTER", 8 | "AND", 9 | "ANY", 10 | "AS", 11 | "ASC", 12 | "AUTHORIZATION", 13 | "BACKUP", 14 | "BEGIN", 15 | "BETWEEN", 16 | "BREAK", 17 | "BROWSE", 18 | "BULK", 19 | "BY", 20 | "CASCADE", 21 | "CASE", 22 | "CHECK", 23 | "CHECKPOINT", 24 | "CLOSE", 25 | "CLUSTERED", 26 | "COALESCE", 27 | "COLLATE", 28 | "COLUMN", 29 | "COMMIT", 30 | "COMPUTE", 31 | "CONSTRAINT", 32 | "CONTAINS", 33 | "CONTAINSTABLE", 34 | "CONTINUE", 35 | "CONVERT", 36 | "CREATE", 37 | "CROSS", 38 | "CURRENT", 39 | "CURRENT_DATE", 40 | "CURRENT_TIME", 41 | "CURRENT_TIMESTAMP", 42 | "CURRENT_USER", 43 | "CURSOR", 44 | "DATABASE", 45 | "DBCC", 46 | "DEALLOCATE", 47 | "DECLARE", 48 | "DEFAULT", 49 | "DELETE", 50 | "DENY", 51 | "DESC", 52 | "DISK", 53 | "DISTINCT", 54 | "DISTRIBUTED", 55 | "DOUBLE", 56 | "DROP", 57 | "DUMP", 58 | "ELSE", 59 | "END", 60 | "ERRLVL", 61 | "ESCAPE", 62 | "EXCEPT", 63 | "EXEC", 64 | "EXECUTE", 65 | "EXISTS", 66 | "EXIT", 67 | "EXTERNAL", 68 | "FETCH", 69 | "FILE", 70 | "FILLFACTOR", 71 | "FOR", 72 | "FOREIGN", 73 | "FREETEXT", 74 | "FREETEXTTABLE", 75 | "FROM", 76 | "FULL", 77 | "FUNCTION", 78 | "GOTO", 79 | "GRANT", 80 | "GROUP", 81 | "HAVING", 82 | "HOLDLOCK", 83 | "IDENTITY", 84 | "IDENTITYCOL", 85 | "IDENTITY_INSERT", 86 | "IF", 87 | "IN", 88 | "INDEX", 89 | "INNER", 90 | "INSERT", 91 | "INTERSECT", 92 | "INTO", 93 | "IS", 94 | "JOIN", 95 | "KEY", 96 | "KILL", 97 | "LEFT", 98 | "LIKE", 99 | "LINENO", 100 | "LOAD", 101 | "MERGE", 102 | "NATIONAL", 103 | "NOCHECK", 104 | "NONCLUSTERED", 105 | "NOT", 106 | "NULL", 107 | "NULLIF", 108 | "OF", 109 | "OFF", 110 | "OFFSETS", 111 | "ON", 112 | "OPEN", 113 | "OPENDATASOURCE", 114 | "OPENQUERY", 115 | "OPENROWSET", 116 | "OPENXML", 117 | "OPTION", 118 | "OR", 119 | "ORDER", 120 | "OUTER", 121 | "OVER", 122 | "PERCENT", 123 | "PIVOT", 124 | "PLAN", 125 | "PRECISION", 126 | "PRIMARY", 127 | "PRINT", 128 | "PROC", 129 | "PROCEDURE", 130 | "PUBLIC", 131 | "RAISERROR", 132 | "READ", 
133 | "READTEXT", 134 | "RECONFIGURE", 135 | "REFERENCES", 136 | "REPLICATION", 137 | "RESTORE", 138 | "RESTRICT", 139 | "RETURN", 140 | "REVERT", 141 | "REVOKE", 142 | "RIGHT", 143 | "ROLLBACK", 144 | "ROWCOUNT", 145 | "ROWGUIDCOL", 146 | "RULE", 147 | "SAVE", 148 | "SCHEMA", 149 | "SECURITYAUDIT", 150 | "SELECT", 151 | "SEMANTICKEYPHRASETABLE", 152 | "SEMANTICSIMILARITYDETAILSTABLE", 153 | "SEMANTICSIMILARITYTABLE", 154 | "SESSION_USER", 155 | "SET", 156 | "SETUSER", 157 | "SHUTDOWN", 158 | "SOME", 159 | "STATISTICS", 160 | "SYSTEM_USER", 161 | "TABLE", 162 | "TABLESAMPLE", 163 | "TEXTSIZE", 164 | "THEN", 165 | "TO", 166 | "TOP", 167 | "TRAN", 168 | "TRANSACTION", 169 | "TRIGGER", 170 | "TRUNCATE", 171 | "TRY_CONVERT", 172 | "TSEQUAL", 173 | "UNION", 174 | "UNIQUE", 175 | "UNPIVOT", 176 | "UPDATE", 177 | "UPDATETEXT", 178 | "USE", 179 | "USER", 180 | "VALUES", 181 | "VARYING", 182 | "VIEW", 183 | "WAITFOR", 184 | "WHEN", 185 | "WHERE", 186 | "WHILE", 187 | "WITH", 188 | "WITHIN GROUP", 189 | "WRITETEXT", 190 | } 191 | -------------------------------------------------------------------------------- /dialect/oracle.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | var oracleReservedWords = []string{ 4 | "A", "ADD", "ACCESSIBLE", "AGENT", "AGGREGATE", "ARRAY", "ATTRIBUTE", "AUTHID", "AVG", 5 | "BFILE_BASE", "BINARY", "BLOB_BASE", "BLOCK", "BODY", "BOTH", "BOUND", "BULK", "BYTE", 6 | "C", "CALL", "CALLING", "CASCADE", "CHAR", "CHAR_BASE", "CHARACTER", "CHARSET", "CHARSETFORM", "CHARSETID", "CLOB_BASE", "CLONE", "CLOSE", "COLLECT", "COMMENT", "COMMIT", "COMMITTED", "COMPILED", "CONSTANT", "CONSTRUCTOR", "CONTEXT", "CONTINUE", "CONVERT", "COUNT", "CREDENTIAL", "CURRENT", "CUSTOMDATUM", 7 | "DANGLING", "DATA", "DATE", "DATE_BASE", "DAY", "DEFINE", "DELETE", "DETERMINISTIC", "DIRECTORY", "DOUBLE", "DURATION", 8 | "ELEMENT", "ELSIF", "EMPTY", "ESCAPE", "EXCEPT", "EXCEPTIONS", "EXECUTE", "EXISTS", "EXIT", 
"EXTERNAL", 9 | "FINAL", "FIRST", "FIXED", "FLOAT", "FROM", "FORALL", "FORCE", 10 | "GENERAL", 11 | "HASH", "HEAP", "HIDDEN", "HOUR", 12 | "IMMEDIATE", "IMMUTABLE", "INCLUDING", "INDICATOR", "INDICES", "INFINITE", "INSTANTIABLE", "INT", "INTERFACE", "INTERVAL", "INVALIDATE", "ISOLATION", 13 | "JAVA", 14 | "LANGUAGE", "LARGE", "LEADING", "LENGTH", "LEVEL", "LIBRARY", "LIKE2", "LIKE4", "LIKEC", "LIMIT", "LIMITED", "LOCAL", "LONG", "LOOP", 15 | "MAP", "MAX", "MAXLEN", "MEMBER", "MERGE", "MIN", "MINUTE", "MOD", "MODIFY", "MONTH", "MULTISET", "MUTABLE", 16 | "NAME", "NAN", "NATIONAL", "NATIVE", "NCHAR", "NEW", "NOCOPY", "NUMBER_BASE", 17 | "OBJECT", "OCICOLL", "OCIDATE", "OCIDATETIME", "OCIDURATION", "OCIINTERVAL", "OCILOBLOCATOR", "OCINUMBER", "OCIRAW", "OCIREF", "OCIREFCURSOR", "OCIROWID", "OCISTRING", "OCITYPE", "OLD", "ONLY", "OPAQUE", "OPEN", "OPERATOR", "ORACLE", "ORADATA", "ORGANIZATION", "ORLANY", "ORLVARY", "OTHERS", "OUT", "OVERRIDING", 18 | "PACKAGE", "PARALLEL_ENABLE", "PARAMETER", "PARAMETERS", "PARENT", "PARTITION", "PASCAL", "PERSISTABLE", "PIPE", "PIPELINED", "PLUGGABLE", "POLYMORPHIC", "PRAGMA", "PRECISION", "PRIOR", "PRIVATE", 19 | "RAISE", "RANGE", "RAW", "READ", "RECORD", "REF", "REFERENCE", "RELIES_ON", "REM", "REMAINDER", "RENAME", "RESULT", "RESULT_CACHE", "RETURN", "RETURNING", "REVERSE", "ROLLBACK", "ROW", 20 | "SAMPLE", "SAVE", "SAVEPOINT", "SB1", "SB2", "SB4", "SECOND", "SEGMENT", "SELF", "SEPARATE", "SEQUENCE", "SERIALIZABLE", "SET", "SHORT", "SIZE_T", "SOME", "SPARSE", "SQLCODE", "SQLDATA", "SQLNAME", "SQLSTATE", "STANDARD", "STATIC", "STDDEV", "STORED", "STRING", "STRUCT", "STYLE", "SUBMULTISET", "SUBPARTITION", "SUBSTITUTABLE", "SUM", "SYNONYM", 21 | "TDO", "THE", "TIME", "TIMESTAMP", "TIMEZONE_ABBR", "TIMEZONE_HOUR", "TIMEZONE_MINUTE", "TIMEZONE_REGION", "TRAILING", "TRANSACTION", "TRANSACTIONAL", "TRUSTED", 22 | "UB1", "UB2", "UB4", "UNDER", "UNPLUG", "UNSIGNED", "UNTRUSTED", "USE", "USING", 23 | "VALIST", "VALUE", "VARIABLE", 
"VARIANCE", "VARRAY", "VARYING", "VOID", 24 | "WHILE", "WORK", "WRAPPED", "WRITE", 25 | "YEAR", 26 | "ZONE", 27 | } 28 | 29 | var oracleKeyWords = []string{ 30 | "ALL", "ALTER", "AND", "ANY", "AS", "ASC", "AT", 31 | "BEGIN", "BETWEEN", "BY", 32 | "CASE", "CHECK", "CLUSTERS", "CLUSTER", "COLAUTH", "COLUMNS", "COMPRESS", "CONNECT", "CRASH", "CREATE", "CURSOR", 33 | "DECLARE", "DEFAULT", "DESC", "DISTINCT", "DROP", 34 | "ELSE", "END", "EXCEPTION", "EXCLUSIVE", 35 | "FETCH", "FOR", "FROM", "FUNCTION", 36 | "GOTO", "GRANT", "GROUP", 37 | "HAVING", 38 | "IDENTIFIED", "IF", "IN", "INDEX", "INDEXES", "INSERT", "INTERSECT", "INTO", "IS", 39 | "LIKE", "LOCK", 40 | "MINUS", "MODE", 41 | "NOCOMPRESS", "NOT", "NOWAIT", "NULL", 42 | "OF", "ON", "OPTION", "OR", "ORDER", "OVERLAPS", 43 | "PROCEDURE", "PUBLIC", 44 | "RESOURCE", "REVOKE", 45 | "SELECT", "SHARE", "SIZE", "SQL", "START", "SUBTYPE", 46 | "TABAUTH", "TABLE", "THEN", "TO", "TYPE", 47 | "UNION", "UNIQUE", "UPDATE", 48 | "VALUES", "VIEW", "VIEWS", 49 | "WHEN", "WHERE", "WITH", 50 | } 51 | -------------------------------------------------------------------------------- /dialect/sqlite.go: -------------------------------------------------------------------------------- 1 | package dialect 2 | 3 | var sqliteKeywords = []string{ 4 | "ABORT", 5 | "ACTION", 6 | "ADD", 7 | "AFTER", 8 | "ALL", 9 | "ALTER", 10 | "ALWAYS", 11 | "ANALYZE", 12 | "AND", 13 | "AS", 14 | "ASC", 15 | "ATTACH", 16 | "AUTOINCREMENT", 17 | "BEFORE", 18 | "BEGIN", 19 | "BETWEEN", 20 | "BY", 21 | "CASCADE", 22 | "CASE", 23 | "CAST", 24 | "CHECK", 25 | "COLLATE", 26 | "COLUMN", 27 | "COMMIT", 28 | "CONFLICT", 29 | "CONSTRAINT", 30 | "CREATE", 31 | "CROSS", 32 | "CURRENT", 33 | "CURRENT_DATE", 34 | "CURRENT_TIME", 35 | "CURRENT_TIMESTAMP", 36 | "DATABASE", 37 | "DEFAULT", 38 | "DEFERRABLE", 39 | "DEFERRED", 40 | "DELETE", 41 | "DESC", 42 | "DETACH", 43 | "DISTINCT", 44 | "DO", 45 | "DROP", 46 | "EACH", 47 | "ELSE", 48 | "END", 49 | "ESCAPE", 50 | "EXCEPT", 51 | 
"EXCLUDE", 52 | "EXCLUSIVE", 53 | "EXISTS", 54 | "EXPLAIN", 55 | "FAIL", 56 | "FILTER", 57 | "FIRST", 58 | "FOLLOWING", 59 | "FOR", 60 | "FOREIGN", 61 | "FROM", 62 | "FULL", 63 | "GENERATED", 64 | "GLOB", 65 | "GROUP", 66 | "GROUPS", 67 | "HAVING", 68 | "IF", 69 | "IGNORE", 70 | "IMMEDIATE", 71 | "IN", 72 | "INDEX", 73 | "INDEXED", 74 | "INITIALLY", 75 | "INNER", 76 | "INSERT", 77 | "INSTEAD", 78 | "INTERSECT", 79 | "INTO", 80 | "IS", 81 | "ISNULL", 82 | "JOIN", 83 | "KEY", 84 | "LAST", 85 | "LEFT", 86 | "LIKE", 87 | "LIMIT", 88 | "MATCH", 89 | "NATURAL", 90 | "NO", 91 | "NOT", 92 | "NOTHING", 93 | "NOTNULL", 94 | "NULL", 95 | "NULLS", 96 | "OF", 97 | "OFFSET", 98 | "ON", 99 | "OR", 100 | "ORDER", 101 | "OTHERS", 102 | "OUTER", 103 | "OVER", 104 | "PARTITION", 105 | "PLAN", 106 | "PRAGMA", 107 | "PRECEDING", 108 | "PRIMARY", 109 | "QUERY", 110 | "RAISE", 111 | "RANGE", 112 | "RECURSIVE", 113 | "REFERENCES", 114 | "REGEXP", 115 | "REINDEX", 116 | "RELEASE", 117 | "RENAME", 118 | "REPLACE", 119 | "RESTRICT", 120 | "RIGHT", 121 | "ROLLBACK", 122 | "ROW", 123 | "ROWS", 124 | "SAVEPOINT", 125 | "SELECT", 126 | "SET", 127 | "TABLE", 128 | "TEMP", 129 | "TEMPORARY", 130 | "THEN", 131 | "TIES", 132 | "TO", 133 | "TRANSACTION", 134 | "TRIGGER", 135 | "UNBOUNDED", 136 | "UNION", 137 | "UNIQUE", 138 | "UPDATE", 139 | "USING", 140 | "VACUUM", 141 | "VALUES", 142 | "VIEW", 143 | "VIRTUAL", 144 | "WHEN", 145 | "WHERE", 146 | "WINDOW", 147 | "WITH", 148 | "WITHOUT", 149 | } 150 | -------------------------------------------------------------------------------- /doc/develop.md: -------------------------------------------------------------------------------- 1 | ## Start test databases 2 | 3 | ```sh 4 | docker-compose up -d 5 | ``` 6 | 7 | ## MySQL setup 8 | 9 | ```sh 10 | wget https://downloads.mysql.com/docs/world.sql.gz 11 | gunzip world.sql.gz 12 | 13 | # MySQL 5.6 14 | mysql -u root -proot -h 127.0.0.1 -P 13305 < world.sql 15 | # MySQL 5.7 16 | mysql -u root -proot -h 127.0.0.1
-P 13306 < world.sql 17 | # MySQL 8 18 | mysql -u root -proot -h 127.0.0.1 -P 13307 < world.sql 19 | rm world.sql 20 | ``` 21 | 22 | ## Export keyword & function list 23 | 24 | ```sh 25 | mysql -u root -proot -h 127.0.0.1 -P 13305 -D mysql < help_categories.sql > ./export/help_categories_mysql56.txt 26 | mysql -u root -proot -h 127.0.0.1 -P 13306 -D mysql < help_categories.sql > ./export/help_categories_mysql57.txt 27 | mysql -u root -proot -h 127.0.0.1 -P 13307 -D mysql < help_categories.sql > ./export/help_categories_mysql8.txt 28 | # Export keyword list 29 | mysql -u root -proot -h 127.0.0.1 -P 13305 -D mysql < help_keywords_mysql56.sql > ./export/help_keywords_mysql56.txt 30 | mysql -u root -proot -h 127.0.0.1 -P 13306 -D mysql < help_keywords_mysql57.sql > ./export/help_keywords_mysql57.txt 31 | mysql -u root -proot -h 127.0.0.1 -P 13307 -D mysql < help_keywords_mysql8.sql > ./export/help_keywords_mysql8.txt 32 | # Export function list 33 | mysql -u root -proot -h 127.0.0.1 -P 13305 -D mysql < help_functions_mysql56.sql > ./export/help_functions_mysql56.txt 34 | mysql -u root -proot -h 127.0.0.1 -P 13306 -D mysql < help_functions_mysql57.sql > ./export/help_functions_mysql57.txt 35 | mysql -u root -proot -h 127.0.0.1 -P 13307 -D mysql < help_functions_mysql8.sql > ./export/help_functions_mysql8.txt 36 | ``` 37 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | mysql8: 4 | image: mysql:8 5 | container_name: sqls_mysql8 6 | environment: 7 | MYSQL_ROOT_PASSWORD: root 8 | MYSQL_DATABASE: world 9 | MYSQL_USER: docker 10 | MYSQL_PASSWORD: docker 11 | TZ: 'Asia/Tokyo' 12 | command: mysqld --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci 13 | volumes: 14 | - ./.docker/mysql8/data:/var/lib/mysql # MySQL's datadir; each server version needs its own host dir 15 | ports: 16 | - 13307:3306 17 | mysql57: 18 | image: mysql:5.7 19 | container_name:
sqls_mysql57 20 | environment: 21 | MYSQL_ROOT_PASSWORD: root 22 | MYSQL_DATABASE: world 23 | MYSQL_USER: docker 24 | MYSQL_PASSWORD: docker 25 | TZ: 'Asia/Tokyo' 26 | command: mysqld --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci 27 | volumes: 28 | - ./.docker/mysql57/data:/var/lib/mysql # MySQL's datadir; separate from the 8 and 5.6 data 29 | ports: 30 | - 13306:3306 31 | mysql56: 32 | image: mysql:5.6 33 | container_name: sqls_mysql56 34 | environment: 35 | MYSQL_ROOT_PASSWORD: root 36 | MYSQL_DATABASE: world 37 | MYSQL_USER: docker 38 | MYSQL_PASSWORD: docker 39 | TZ: 'Asia/Tokyo' 40 | command: mysqld --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci 41 | volumes: 42 | - ./.docker/mysql56/data:/var/lib/mysql # MySQL's datadir; separate from the 8 and 5.7 data 43 | ports: 44 | - 13305:3306 45 | postgres12: 46 | image: postgres:12-alpine 47 | container_name: sqls_postgres12 48 | ports: 49 | - "15432:5432" 50 | environment: 51 | - POSTGRES_USER=postgres 52 | - POSTGRES_PASSWORD=mysecretpassword1234 53 | - PGPASSWORD=mysecretpassword1234 54 | - POSTGRES_DB=dvdrental 55 | - DATABASE_HOST=localhost 56 | volumes: 57 | - ./.docker/postgres/data:/var/lib/postgresql/data 58 | mssql2019: 59 | image: mcr.microsoft.com/mssql/server:2019-latest 60 | container_name: sqls_mssql2019 61 | ports: 62 | - "11433:1433" 63 | environment: 64 | SA_PASSWORD: Passw0rd 65 | ACCEPT_EULA: Y 66 | volumes: 67 | - ./.docker/mssql/data:/var/opt/mssql/data 68 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/sqls-server/sqls 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/CodinGame/h2go v0.6.1 7 | github.com/denisenkom/go-mssqldb v0.12.3 8 | github.com/go-sql-driver/mysql v1.7.1 9 | github.com/godror/godror v0.41.0 10 | github.com/google/go-cmp v0.5.9 11 | github.com/jackc/pgx/v4 v4.18.1 12 | github.com/jfcote87/sshdb v0.5.3 13 | github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 // indirect 14 |
github.com/sourcegraph/jsonrpc2 v0.2.0 15 | github.com/urfave/cli/v2 v2.27.0 16 | golang.org/x/crypto v0.17.0 17 | gopkg.in/yaml.v2 v2.4.0 18 | ) 19 | 20 | require ( 21 | github.com/k0kubun/pp v3.0.1+incompatible 22 | github.com/mattn/go-sqlite3 v1.14.19 23 | github.com/olekukonko/tablewriter v0.0.5 24 | github.com/vertica/vertica-sql-go v1.3.3 25 | ) 26 | 27 | require ( 28 | github.com/ClickHouse/ch-go v0.58.2 // indirect 29 | github.com/ClickHouse/clickhouse-go/v2 v2.17.1 // indirect 30 | github.com/andybalholm/brotli v1.0.6 // indirect 31 | github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect 32 | github.com/elastic/go-sysinfo v1.11.2 // indirect 33 | github.com/elastic/go-windows v1.0.1 // indirect 34 | github.com/go-faster/city v1.0.1 // indirect 35 | github.com/go-faster/errors v0.6.1 // indirect 36 | github.com/go-logfmt/logfmt v0.6.0 // indirect 37 | github.com/godror/knownpb v0.1.1 // indirect 38 | github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect 39 | github.com/golang-sql/sqlexp v0.1.0 // indirect 40 | github.com/google/uuid v1.5.0 // indirect 41 | github.com/jackc/chunkreader/v2 v2.0.1 // indirect 42 | github.com/jackc/pgconn v1.14.1 // indirect 43 | github.com/jackc/pgio v1.0.0 // indirect 44 | github.com/jackc/pgpassfile v1.0.0 // indirect 45 | github.com/jackc/pgproto3/v2 v2.3.2 // indirect 46 | github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect 47 | github.com/jackc/pgtype v1.14.0 // indirect 48 | github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901 // indirect 49 | github.com/klauspost/compress v1.16.7 // indirect 50 | github.com/mattn/go-colorable v0.1.6 // indirect 51 | github.com/mattn/go-isatty v0.0.12 // indirect 52 | github.com/mattn/go-runewidth v0.0.15 // indirect 53 | github.com/paulmach/orb v0.10.0 // indirect 54 | github.com/pierrec/lz4/v4 v4.1.18 // indirect 55 | github.com/pkg/errors v0.9.1 // indirect 56 | github.com/prometheus/procfs v0.12.0 // indirect 57 | 
github.com/rivo/uniseg v0.4.4 // indirect 58 | github.com/russross/blackfriday/v2 v2.1.0 // indirect 59 | github.com/segmentio/asm v1.2.0 // indirect 60 | github.com/shopspring/decimal v1.3.1 // indirect 61 | github.com/sirupsen/logrus v1.9.3 // indirect 62 | github.com/xrash/smetrics v0.0.0-20231213231151-1d8dd44e695e // indirect 63 | go.opentelemetry.io/otel v1.19.0 // indirect 64 | go.opentelemetry.io/otel/trace v1.19.0 // indirect 65 | golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect 66 | golang.org/x/sys v0.15.0 // indirect 67 | golang.org/x/text v0.14.0 // indirect 68 | google.golang.org/protobuf v1.32.0 // indirect 69 | gopkg.in/yaml.v3 v3.0.1 // indirect 70 | howett.net/plist v1.0.1 // indirect 71 | ) 72 | -------------------------------------------------------------------------------- /imgs/sqls-completion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sqls-server/sqls/efe7f66d16e9479e242d3876c2a4a878ee190568/imgs/sqls-completion.gif -------------------------------------------------------------------------------- /imgs/sqls-fk_joins.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sqls-server/sqls/efe7f66d16e9479e242d3876c2a4a878ee190568/imgs/sqls-fk_joins.gif -------------------------------------------------------------------------------- /imgs/sqls_document_format.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sqls-server/sqls/efe7f66d16e9479e242d3876c2a4a878ee190568/imgs/sqls_document_format.gif -------------------------------------------------------------------------------- /imgs/sqls_hover.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sqls-server/sqls/efe7f66d16e9479e242d3876c2a4a878ee190568/imgs/sqls_hover.gif 
-------------------------------------------------------------------------------- /imgs/sqls_signature_help.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sqls-server/sqls/efe7f66d16e9479e242d3876c2a4a878ee190568/imgs/sqls_signature_help.gif -------------------------------------------------------------------------------- /internal/completer/completer_test.go: -------------------------------------------------------------------------------- 1 | package completer 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/sqls-server/sqls/internal/lsp" 8 | ) 9 | 10 | func TestGetBeforeCursorText(t *testing.T) { 11 | input := `SELECT 12 | a, b, c 13 | FROM 14 | hogetable 15 | ` 16 | tests := []struct { 17 | in string 18 | line int 19 | char int 20 | out string 21 | }{ 22 | {input, 1, 2, "SE"}, 23 | {input, 2, 3, "SELECT\na, "}, 24 | {input, 3, 4, "SELECT\na, b, c\nFROM"}, 25 | {input, 4, 5, "SELECT\na, b, c\nFROM\nhoget"}, 26 | } 27 | for _, tt := range tests { 28 | got := getBeforeCursorText(tt.in, tt.line, tt.char) 29 | if tt.out != got { 30 | t.Errorf("want %#v, got %#v", tt.out, got) 31 | } 32 | } 33 | } 34 | 35 | func TestGetLastWord(t *testing.T) { 36 | input := `SELECT 37 | a, b, c 38 | FROM 39 | hogetable 40 | ` 41 | tests := []struct { 42 | name string 43 | in string 44 | line int 45 | char int 46 | out string 47 | }{ 48 | {"", "SELECT FROM def", 1, 7, ""}, 49 | {"", input, 1, 2, "SE"}, 50 | {"", input, 2, 3, ""}, 51 | {"", input, 3, 4, "FROM"}, 52 | {"", input, 3, 6, ""}, 53 | {"", input, 4, 5, "h"}, 54 | {"", "`ident", 1, 6, "`ident"}, 55 | {"", "parent.`ident", 1, 13, "`ident"}, 56 | {"", "`parent`.`ident", 1, 15, "`ident"}, 57 | } 58 | for _, tt := range tests { 59 | t.Run(tt.name, func(t *testing.T) { 60 | got := getLastWord(tt.in, tt.line, tt.char) 61 | if tt.out != got { 62 | t.Errorf("want %#v, got %#v", tt.out, got) 63 | } 64 | }) 65 | } 66 | } 67 | 68 | func 
Test_completionTypeIs(t *testing.T) { 69 | type args struct { 70 | } 71 | tests := []struct { 72 | name string 73 | completionTypes []completionType 74 | expect completionType 75 | want bool 76 | }{ 77 | { 78 | completionTypes: []completionType{ 79 | CompletionTypeColumn, 80 | }, 81 | expect: CompletionTypeColumn, 82 | want: true, 83 | }, 84 | { 85 | completionTypes: []completionType{ 86 | CompletionTypeTable, 87 | CompletionTypeView, 88 | CompletionTypeFunction, 89 | CompletionTypeColumn, 90 | }, 91 | expect: CompletionTypeColumn, 92 | want: true, 93 | }, 94 | { 95 | completionTypes: []completionType{ 96 | CompletionTypeTable, 97 | CompletionTypeView, 98 | CompletionTypeFunction, 99 | }, 100 | expect: CompletionTypeColumn, 101 | want: false, 102 | }, 103 | } 104 | for _, tt := range tests { 105 | t.Run(tt.name, func(t *testing.T) { 106 | if got := completionTypeIs(tt.completionTypes, tt.expect); got != tt.want { 107 | t.Errorf("completionTypeIs() = %v, want %v", got, tt.want) 108 | } 109 | }) 110 | } 111 | } 112 | 113 | func TestComplete(t *testing.T) { 114 | tests := []struct { 115 | name string 116 | text string 117 | lowerCase bool 118 | expected []lsp.CompletionItem 119 | }{ 120 | { 121 | name: "keyword", 122 | text: "sel", 123 | expected: []lsp.CompletionItem{ 124 | { 125 | Label: "SELECT", 126 | Kind: lsp.KeywordCompletion, 127 | Detail: "keyword", 128 | SortText: "9999SELECT", 129 | }, 130 | }, 131 | }, 132 | { 133 | name: "keyword-lowercase", 134 | text: "sel", 135 | lowerCase: true, 136 | expected: []lsp.CompletionItem{ 137 | { 138 | Label: "select", 139 | Kind: lsp.KeywordCompletion, 140 | Detail: "keyword", 141 | SortText: "9999select", 142 | }, 143 | }, 144 | }, 145 | } 146 | 147 | for _, tt := range tests { 148 | t.Run("", func(t *testing.T) { 149 | c := NewCompleter(nil) 150 | got, err := c.Complete("sel", lsp.CompletionParams{ 151 | TextDocumentPositionParams: lsp.TextDocumentPositionParams{ 152 | Position: lsp.Position{ 153 | Line: 0, 154 | 
Character: len(tt.text), 155 | }, 156 | }, 157 | }, tt.lowerCase) 158 | if err != nil { 159 | t.Fatal(err) 160 | } 161 | 162 | if !reflect.DeepEqual(got, tt.expected) { 163 | t.Errorf("\nwant: %v\ngot: %v", tt.expected, got) 164 | } 165 | }) 166 | } 167 | } 168 | 169 | func TestGenerateAlias(t *testing.T) { 170 | noMatchesTable := make(map[string]interface{}) 171 | noMatchesTable["XX"] = true 172 | matchesTable := make(map[string]interface{}) 173 | matchesTable["XX"] = true 174 | matchesTable["T1"] = true 175 | 176 | tests := []struct { 177 | name string 178 | table string 179 | tMap map[string]interface{} 180 | want string 181 | }{ 182 | { 183 | "no matches", 184 | "Table", 185 | noMatchesTable, 186 | "T1", 187 | }, 188 | { 189 | "matches", 190 | "Table", 191 | matchesTable, 192 | "T2", 193 | }, 194 | } 195 | for _, tt := range tests { 196 | t.Run(tt.name, func(t *testing.T) { 197 | if got := generateTableAlias(tt.table, tt.tMap); got != tt.want { 198 | t.Errorf("generateAlias() = %v, want %v", got, tt.want) 199 | } 200 | }) 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /internal/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/sqls-server/sqls/internal/database" 11 | "gopkg.in/yaml.v2" 12 | ) 13 | 14 | var ( 15 | ErrNotFoundConfig = errors.New("NotFound Config") 16 | ) 17 | 18 | var ( 19 | YamlConfigPath = configFilePath("config.yml") 20 | ) 21 | 22 | type Config struct { 23 | LowercaseKeywords bool `json:"lowercaseKeywords" yaml:"lowercaseKeywords"` 24 | Connections []*database.DBConfig `json:"connections" yaml:"connections"` 25 | } 26 | 27 | func (c *Config) Validate() error { 28 | if len(c.Connections) > 0 { 29 | return c.Connections[0].Validate() 30 | } 31 | return nil 32 | } 33 | 34 | func NewConfig() *Config { 35 | cfg := &Config{} 36 | 
cfg.LowercaseKeywords = false 37 | return cfg 38 | } 39 | 40 | func GetDefaultConfig() (*Config, error) { 41 | cfg := NewConfig() 42 | if err := cfg.Load(YamlConfigPath); err != nil { 43 | return nil, err 44 | } 45 | return cfg, nil 46 | } 47 | 48 | func GetConfig(fp string) (*Config, error) { 49 | cfg := NewConfig() 50 | expandPath, err := expand(fp) 51 | if err != nil { 52 | return nil, err 53 | } 54 | if err := cfg.Load(expandPath); err != nil { 55 | return nil, err 56 | } 57 | return cfg, nil 58 | } 59 | 60 | func (c *Config) Load(fp string) error { 61 | if !IsFileExist(fp) { 62 | return ErrNotFoundConfig 63 | } 64 | 65 | file, err := os.OpenFile(fp, os.O_RDONLY, 0666) 66 | if err != nil { 67 | return fmt.Errorf("cannot open config, %w", err) 68 | } 69 | defer file.Close() 70 | 71 | b, err := io.ReadAll(file) 72 | if err != nil { 73 | return fmt.Errorf("cannot read config, %w", err) 74 | } 75 | 76 | if err = yaml.Unmarshal(b, c); err != nil { 77 | return fmt.Errorf("failed unmarshal yaml, %w, %s", err, string(b)) 78 | } 79 | 80 | if err := c.Validate(); err != nil { 81 | return fmt.Errorf("failed validation, %w", err) 82 | } 83 | return nil 84 | } 85 | 86 | func IsFileExist(fPath string) bool { 87 | _, err := os.Stat(fPath) 88 | return err == nil || !os.IsNotExist(err) 89 | } 90 | 91 | func configFilePath(fileName string) string { 92 | if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" { 93 | return filepath.Join(xdgConfigHome, "sqls", fileName) 94 | } 95 | 96 | homeDir, err := os.UserHomeDir() 97 | if err != nil { 98 | panic(err) 99 | } 100 | 101 | return filepath.Join(homeDir, ".config", "sqls", fileName) 102 | } 103 | 104 | func expand(path string) (string, error) { 105 | if len(path) == 0 || path[0] != '~' { 106 | return path, nil 107 | } 108 | 109 | homeDir, err := os.UserHomeDir() 110 | if err != nil { 111 | return "", err 112 | } 113 | return filepath.Join(homeDir, path[1:]), nil 114 | } 115 | 
-------------------------------------------------------------------------------- /internal/config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/google/go-cmp/cmp" 9 | "github.com/sqls-server/sqls/internal/database" 10 | ) 11 | 12 | func TestGetConfig(t *testing.T) { 13 | type args struct { 14 | fp string 15 | } 16 | tests := []struct { 17 | name string 18 | args args 19 | want *Config 20 | wantErr bool 21 | errMsg string 22 | }{ 23 | { 24 | name: "basic", 25 | args: args{ 26 | fp: "basic.yml", 27 | }, 28 | want: &Config{ 29 | LowercaseKeywords: true, 30 | Connections: []*database.DBConfig{ 31 | { 32 | Alias: "sqls_mysql", 33 | Driver: "mysql", 34 | Proto: "tcp", 35 | User: "root", 36 | Passwd: "root", 37 | Host: "127.0.0.1", 38 | Port: 13306, 39 | DBName: "world", 40 | Params: map[string]string{"autocommit": "true", "tls": "skip-verify"}, 41 | }, 42 | { 43 | Alias: "sqls_sqlite3", 44 | Driver: "sqlite3", 45 | DataSourceName: "file:/home/sqls-server/chinook.db", 46 | }, 47 | { 48 | Alias: "sqls_postgresql", 49 | Driver: "postgresql", 50 | Proto: "tcp", 51 | User: "postgres", 52 | Passwd: "mysecretpassword1234", 53 | Host: "127.0.0.1", 54 | Port: 15432, 55 | DBName: "dvdrental", 56 | Params: map[string]string{"sslmode": "disable"}, 57 | }, 58 | { 59 | Alias: "mysql_with_bastion", 60 | Driver: "mysql", 61 | Proto: "tcp", 62 | User: "admin", 63 | Passwd: "Q+ACgv12ABx/", 64 | Host: "192.168.121.163", 65 | Port: 3306, 66 | DBName: "world", 67 | SSHCfg: &database.SSHConfig{ 68 | Host: "192.168.121.168", 69 | Port: 22, 70 | User: "vagrant", 71 | PassPhrase: "passphrase1234", 72 | PrivateKey: "/home/sqls-server/.ssh/id_rsa", 73 | }, 74 | }, 75 | }, 76 | }, 77 | wantErr: false, 78 | }, 79 | { 80 | name: "no driver", 81 | args: args{ 82 | fp: "no_driver.yml", 83 | }, 84 | want: nil, 85 | wantErr: true, 86 | errMsg: "failed 
validation, required: connections[].driver", 87 | }, 88 | { 89 | name: "no connection", 90 | args: args{ 91 | fp: "no_connection.yml", 92 | }, 93 | want: nil, 94 | wantErr: true, 95 | errMsg: "failed validation, required: connections[].dataSourceName or connections[].proto", 96 | }, 97 | { 98 | name: "no user", 99 | args: args{ 100 | fp: "no_user.yml", 101 | }, 102 | want: nil, 103 | wantErr: true, 104 | errMsg: "failed validation, required: connections[].user", 105 | }, 106 | { 107 | name: "invalid proto", 108 | args: args{ 109 | fp: "invalid_proto.yml", 110 | }, 111 | want: nil, 112 | wantErr: true, 113 | errMsg: "failed validation, invalid: connections[].proto", 114 | }, 115 | { 116 | name: "no path", 117 | args: args{ 118 | fp: "no_path.yml", 119 | }, 120 | want: nil, 121 | wantErr: true, 122 | errMsg: "failed validation, required: connections[].path", 123 | }, 124 | { 125 | name: "no dsn", 126 | args: args{ 127 | fp: "no_dsn.yml", 128 | }, 129 | want: &Config{ 130 | Connections: []*database.DBConfig{ 131 | { 132 | Alias: "sqls_sqlite3", 133 | Driver: "sqlite3", 134 | DataSourceName: "", 135 | }, 136 | }, 137 | }, 138 | wantErr: true, 139 | errMsg: "failed validation, required: connections[].dataSourceName", 140 | }, 141 | { 142 | name: "no ssh host", 143 | args: args{ 144 | fp: "no_ssh_host.yml", 145 | }, 146 | want: nil, 147 | wantErr: true, 148 | errMsg: "failed validation, required: connections[]sshConfig.host", 149 | }, 150 | { 151 | name: "no ssh user", 152 | args: args{ 153 | fp: "no_ssh_user.yml", 154 | }, 155 | want: nil, 156 | wantErr: true, 157 | errMsg: "failed validation, required: connections[].sshConfig.user", 158 | }, 159 | { 160 | name: "no ssh private key", 161 | args: args{ 162 | fp: "no_ssh_private_key.yml", 163 | }, 164 | want: nil, 165 | wantErr: true, 166 | errMsg: "failed validation, required: connections[].sshConfig.privateKey", 167 | }, 168 | { 169 | name: "oracle config", 170 | args: args{ 171 | fp: "oracle.yaml", 172 | }, 173 | want: 
&Config{ 174 | Connections: []*database.DBConfig{ 175 | { 176 | Alias: "TestDB", 177 | Driver: "oracle", 178 | DataSourceName: "SYSTEM/P1ssword@localhost:1521/XE", 179 | }, 180 | }, 181 | }, 182 | wantErr: true, 183 | errMsg: "failed validation, required: connections[].sshConfig.privateKey", 184 | }, 185 | } 186 | for _, tt := range tests { 187 | packageDir, err := os.Getwd() 188 | if err != nil { 189 | t.Fatalf("cannot get package path, Err=%v", err) 190 | } 191 | testFile := filepath.Join(packageDir, "testdata", tt.args.fp) 192 | 193 | t.Run(tt.name, func(t *testing.T) { 194 | got, err := GetConfig(testFile) 195 | if err != nil { 196 | if tt.wantErr { 197 | if err.Error() != tt.errMsg { 198 | t.Errorf("unmatch error message, want:%q got:%q", tt.errMsg, err.Error()) 199 | } 200 | } else { 201 | t.Errorf("GetConfig() error = %v, wantErr %v", err, tt.wantErr) 202 | return 203 | } 204 | } 205 | if diff := cmp.Diff(tt.want, got); diff != "" { 206 | t.Errorf("unmatch (- want, + got):\n%s", diff) 207 | } 208 | }) 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /internal/config/testdata/basic.yml: -------------------------------------------------------------------------------- 1 | lowercaseKeywords: true 2 | connections: 3 | - alias: sqls_mysql 4 | driver: mysql 5 | dataSourceName: "" 6 | proto: tcp 7 | user: root 8 | passwd: root 9 | host: 127.0.0.1 10 | port: 13306 11 | path: "" 12 | dbName: world 13 | params: 14 | autocommit: "true" 15 | tls: skip-verify 16 | - alias: sqls_sqlite3 17 | driver: sqlite3 18 | dataSourceName: "file:/home/sqls-server/chinook.db" 19 | - alias: sqls_postgresql 20 | driver: postgresql 21 | dataSourceName: "" 22 | proto: tcp 23 | user: postgres 24 | passwd: mysecretpassword1234 25 | host: 127.0.0.1 26 | port: 15432 27 | path: "" 28 | dbName: dvdrental 29 | params: 30 | sslmode: disable 31 | - alias: mysql_with_bastion 32 | driver: mysql 33 | dataSourceName: "" 34 | proto: tcp 35 | user: 
admin 36 | passwd: Q+ACgv12ABx/ 37 | host: 192.168.121.163 38 | port: 3306 39 | dbName: world 40 | sshConfig: 41 | host: 192.168.121.168 42 | port: 22 43 | user: vagrant 44 | passPhrase: passphrase1234 45 | privateKey: /home/sqls-server/.ssh/id_rsa 46 | -------------------------------------------------------------------------------- /internal/config/testdata/invalid_proto.yml: -------------------------------------------------------------------------------- 1 | connections: 2 | - alias: sqls_mysql 3 | driver: mysql 4 | dataSourceName: "" 5 | proto: invalid 6 | user: root 7 | passwd: root 8 | host: 127.0.0.1 9 | port: 13306 10 | path: "" 11 | dbName: world 12 | params: 13 | autocommit: "true" 14 | tls: skip-verify 15 | -------------------------------------------------------------------------------- /internal/config/testdata/no_connection.yml: -------------------------------------------------------------------------------- 1 | connections: 2 | - alias: sqls_mysql 3 | driver: mysql 4 | dataSourceName: "" 5 | proto: "" 6 | -------------------------------------------------------------------------------- /internal/config/testdata/no_driver.yml: -------------------------------------------------------------------------------- 1 | connections: 2 | - alias: sqls_mysql 3 | driver: "" 4 | dataSourceName: "" 5 | proto: tcp 6 | user: root 7 | passwd: root 8 | host: 127.0.0.1 9 | port: 13306 10 | path: "" 11 | dbName: world 12 | params: 13 | autocommit: "true" 14 | tls: skip-verify 15 | -------------------------------------------------------------------------------- /internal/config/testdata/no_dsn.yml: -------------------------------------------------------------------------------- 1 | connections: 2 | - alias: sqls_sqlite3 3 | driver: sqlite3 4 | dataSourceName: "" 5 | -------------------------------------------------------------------------------- /internal/config/testdata/no_host.yml: -------------------------------------------------------------------------------- 1 | 
connections: 2 | - alias: sqls_mysql 3 | driver: mysql 4 | dataSourceName: "" 5 | proto: unix 6 | user: root 7 | passwd: root 8 | host: "" 9 | port: 0 10 | path: "" 11 | dbName: world 12 | params: 13 | autocommit: "true" 14 | tls: skip-verify 15 | -------------------------------------------------------------------------------- /internal/config/testdata/no_path.yml: -------------------------------------------------------------------------------- 1 | connections: 2 | - alias: sqls_mysql 3 | driver: mysql 4 | dataSourceName: "" 5 | proto: unix 6 | user: root 7 | passwd: root 8 | host: 127.0.0.1 9 | port: 13306 10 | path: "" 11 | dbName: world 12 | params: 13 | autocommit: "true" 14 | tls: skip-verify 15 | -------------------------------------------------------------------------------- /internal/config/testdata/no_ssh_host.yml: -------------------------------------------------------------------------------- 1 | connections: 2 | - alias: mysql_with_bastion 3 | driver: mysql 4 | dataSourceName: "" 5 | proto: tcp 6 | user: admin 7 | passwd: Q+ACgv12ABx/ 8 | host: 192.168.121.163 9 | port: 3306 10 | dbName: world 11 | sshConfig: 12 | host: "" 13 | port: 22 14 | user: vagrant 15 | passPhrase: passphrase1234 16 | privateKey: /home/sqls-server/.ssh/id_rsa 17 | -------------------------------------------------------------------------------- /internal/config/testdata/no_ssh_private_key.yml: -------------------------------------------------------------------------------- 1 | connections: 2 | - alias: mysql_with_bastion 3 | driver: mysql 4 | dataSourceName: "" 5 | proto: tcp 6 | user: admin 7 | passwd: Q+ACgv12ABx/ 8 | host: 192.168.121.163 9 | port: 3306 10 | dbName: world 11 | sshConfig: 12 | host: 192.168.121.168 13 | port: 22 14 | user: vagrant 15 | passPhrase: passphrase1234 16 | privateKey: "" 17 | -------------------------------------------------------------------------------- /internal/config/testdata/no_ssh_user.yml: 
-------------------------------------------------------------------------------- 1 | connections: 2 | - alias: mysql_with_bastion 3 | driver: mysql 4 | dataSourceName: "" 5 | proto: tcp 6 | user: admin 7 | passwd: Q+ACgv12ABx/ 8 | host: 192.168.121.163 9 | port: 3306 10 | dbName: world 11 | sshConfig: 12 | host: 192.168.121.168 13 | port: 22 14 | user: "" 15 | passPhrase: passphrase1234 16 | privateKey: /home/sqls-server/.ssh/id_rsa 17 | -------------------------------------------------------------------------------- /internal/config/testdata/no_user.yml: -------------------------------------------------------------------------------- 1 | lowercaseKeywords: true 2 | connections: 3 | - alias: sqls_mysql 4 | driver: mysql 5 | dataSourceName: "" 6 | proto: tcp 7 | user: "" 8 | passwd: root 9 | host: 127.0.0.1 10 | port: 13306 11 | path: "" 12 | dbName: world 13 | params: 14 | autocommit: "true" 15 | tls: skip-verify 16 | -------------------------------------------------------------------------------- /internal/config/testdata/oracle.yaml: -------------------------------------------------------------------------------- 1 | connections: 2 | - alias: TestDB 3 | driver: oracle 4 | dataSourceName: "SYSTEM/P1ssword@localhost:1521/XE" 5 | -------------------------------------------------------------------------------- /internal/database/cache.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "sort" 6 | "strings" 7 | ) 8 | 9 | type DBCacheGenerator struct { 10 | repo DBRepository 11 | } 12 | 13 | func NewDBCacheUpdater(repo DBRepository) *DBCacheGenerator { 14 | return &DBCacheGenerator{ 15 | repo: repo, 16 | } 17 | } 18 | 19 | func (u *DBCacheGenerator) GenerateDBCachePrimary(ctx context.Context) (*DBCache, error) { 20 | var err error 21 | dbCache := &DBCache{} 22 | dbCache.defaultSchema, err = u.repo.CurrentSchema(ctx) 23 | if err != nil { 24 | return nil, err 25 | } 26 | schemas, err 
:= u.genSchemaCache(ctx) 27 | if err != nil { 28 | return nil, err 29 | } 30 | dbCache.Schemas = make(map[string]string) 31 | for index, element := range schemas { 32 | dbCache.Schemas[strings.ToUpper(index)] = element 33 | } 34 | 35 | if dbCache.defaultSchema == "" { 36 | var topKey string 37 | for k := range dbCache.Schemas { 38 | topKey = k 39 | continue 40 | } 41 | dbCache.defaultSchema = dbCache.Schemas[topKey] 42 | } 43 | schemaTables, err := u.repo.SchemaTables(ctx) 44 | if err != nil { 45 | return nil, err 46 | } 47 | dbCache.SchemaTables = make(map[string][]string) 48 | for index, element := range schemaTables { 49 | dbCache.SchemaTables[strings.ToUpper(index)] = element 50 | } 51 | 52 | dbCache.ColumnsWithParent, err = u.genColumnCacheCurrent(ctx, dbCache.defaultSchema) 53 | if err != nil { 54 | return nil, err 55 | } 56 | dbCache.ForeignKeys, err = u.genForeignKeysCache(ctx, dbCache.defaultSchema) 57 | if err != nil { 58 | return nil, err 59 | } 60 | return dbCache, nil 61 | } 62 | 63 | func (u *DBCacheGenerator) GenerateDBCacheSecondary(ctx context.Context) (map[string][]*ColumnDesc, error) { 64 | return u.genColumnCacheAll(ctx) 65 | } 66 | 67 | func (u *DBCacheGenerator) genSchemaCache(ctx context.Context) (map[string]string, error) { 68 | dbs, err := u.repo.Schemas(ctx) 69 | if err != nil { 70 | return nil, err 71 | } 72 | databaseMap := map[string]string{} 73 | for _, db := range dbs { 74 | databaseMap[strings.ToUpper(db)] = db 75 | } 76 | return databaseMap, nil 77 | } 78 | 79 | func (u *DBCacheGenerator) genColumnCacheCurrent(ctx context.Context, schemaName string) (map[string][]*ColumnDesc, error) { 80 | columnDescs, err := u.repo.DescribeDatabaseTableBySchema(ctx, schemaName) 81 | if err != nil { 82 | return nil, err 83 | } 84 | return genColumnMap(columnDescs), nil 85 | } 86 | 87 | func (u *DBCacheGenerator) genColumnCacheAll(ctx context.Context) (map[string][]*ColumnDesc, error) { 88 | columnDescs, err := u.repo.DescribeDatabaseTable(ctx) 89 | 
if err != nil { 90 | return nil, err 91 | } 92 | return genColumnMap(columnDescs), nil 93 | } 94 | 95 | func (u *DBCacheGenerator) genForeignKeysCache(ctx context.Context, schemaName string) (map[string]map[string][]*ForeignKey, error) { 96 | retVal := make(map[string]map[string][]*ForeignKey) 97 | fk, err := u.repo.DescribeForeignKeysBySchema(ctx, schemaName) 98 | if err != nil { 99 | return nil, err 100 | } 101 | 102 | for _, cur := range fk { 103 | elem := (*cur)[0] 104 | refs, ok := retVal[elem[0].Table] 105 | if !ok { 106 | refs = make(map[string][]*ForeignKey) 107 | } 108 | refs[elem[1].Table] = append(refs[elem[1].Table], cur) 109 | retVal[elem[0].Table] = refs 110 | 111 | refs, ok = retVal[elem[1].Table] 112 | if !ok { 113 | refs = make(map[string][]*ForeignKey) 114 | } 115 | refs[elem[0].Table] = append(refs[elem[0].Table], cur) 116 | retVal[elem[1].Table] = refs 117 | } 118 | return retVal, nil 119 | } 120 | 121 | func genColumnMap(columnDescs []*ColumnDesc) map[string][]*ColumnDesc { 122 | columnMap := map[string][]*ColumnDesc{} 123 | for _, desc := range columnDescs { 124 | key := columnDatabaseKey(desc.Schema, desc.Table) 125 | columnMap[key] = append(columnMap[key], desc) 126 | } 127 | return columnMap 128 | } 129 | 130 | type DBCache struct { 131 | defaultSchema string 132 | Schemas map[string]string 133 | SchemaTables map[string][]string 134 | ColumnsWithParent map[string][]*ColumnDesc 135 | ForeignKeys map[string]map[string][]*ForeignKey 136 | } 137 | 138 | func (dc *DBCache) Database(dbName string) (db string, ok bool) { 139 | db, ok = dc.Schemas[strings.ToUpper(dbName)] 140 | return 141 | } 142 | 143 | func (dc *DBCache) SortedSchemas() []string { 144 | dbs := []string{} 145 | for _, db := range dc.Schemas { 146 | dbs = append(dbs, db) 147 | } 148 | sort.Strings(dbs) 149 | return dbs 150 | } 151 | 152 | func (dc *DBCache) SortedTablesByDBName(dbName string) (tbls []string, ok bool) { 153 | tbls, ok = dc.SchemaTables[strings.ToUpper(dbName)] 154 | 
sort.Strings(tbls) 155 | return 156 | } 157 | 158 | func (dc *DBCache) SortedTables() []string { 159 | tbls, _ := dc.SortedTablesByDBName(dc.defaultSchema) 160 | return tbls 161 | } 162 | 163 | func (dc *DBCache) ColumnDescs(tableName string) (cols []*ColumnDesc, ok bool) { 164 | cols, ok = dc.ColumnsWithParent[columnDatabaseKey(dc.defaultSchema, tableName)] 165 | return 166 | } 167 | 168 | func (dc *DBCache) ColumnDatabase(dbName, tableName string) (cols []*ColumnDesc, ok bool) { 169 | cols, ok = dc.ColumnsWithParent[columnDatabaseKey(dbName, tableName)] 170 | return 171 | } 172 | 173 | func (dc *DBCache) Column(tableName, colName string) (*ColumnDesc, bool) { 174 | cols, ok := dc.ColumnsWithParent[columnDatabaseKey(dc.defaultSchema, tableName)] 175 | if !ok { 176 | return nil, false 177 | } 178 | for _, col := range cols { 179 | if strings.EqualFold(col.Name, colName) { 180 | return col, true 181 | } 182 | } 183 | return nil, false 184 | } 185 | 186 | func columnDatabaseKey(dbName, tableName string) string { 187 | return strings.ToUpper(dbName) + "\t" + strings.ToUpper(tableName) 188 | } 189 | -------------------------------------------------------------------------------- /internal/database/clickhouse_test.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import "testing" 4 | 5 | func TestGenClickhouseDsn(t *testing.T) { 6 | type testCase struct { 7 | name string 8 | connCfg *DBConfig 9 | want string 10 | wantErr bool 11 | } 12 | 13 | tests := []testCase{ 14 | { 15 | name: "use datasource name", 16 | connCfg: &DBConfig{ 17 | DataSourceName: "clickhouse://user:pwd@localhost:9001", 18 | Driver: "clickhouse", 19 | }, 20 | want: "clickhouse://user:pwd@localhost:9001", 21 | wantErr: false, 22 | }, 23 | { 24 | name: "use config properties", 25 | connCfg: &DBConfig{ 26 | Alias: "", 27 | DataSourceName: "", 28 | Driver: "clickhouse", 29 | Proto: "tcp", 30 | User: "test", 31 | Passwd: "secure", 32 | Host: 
"localhost", 33 | Port: 9001, 34 | Path: "", 35 | DBName: "default", 36 | Params: map[string]string{ 37 | "dial_timeout": "200ms", 38 | }, 39 | }, 40 | want: "clickhouse://test:secure@localhost:9001/default?dial_timeout=200ms", 41 | }, 42 | } 43 | 44 | for _, tt := range tests { 45 | t.Run(tt.name, func(t *testing.T) { 46 | got, err := genClickhouseDsn(tt.connCfg) 47 | 48 | if (err != nil) != tt.wantErr { 49 | t.Errorf("genClickhouseDsn() error = %v, wantErr %v", err, tt.wantErr) 50 | return 51 | } 52 | if got != tt.want { 53 | t.Errorf("got %q, want %q", got, tt.want) 54 | } 55 | }) 56 | 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /internal/database/config.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/sqls-server/sqls/dialect" 9 | "golang.org/x/crypto/ssh" 10 | ) 11 | 12 | type Proto string 13 | 14 | const ( 15 | ProtoTCP Proto = "tcp" 16 | ProtoUDP Proto = "udp" 17 | ProtoUnix Proto = "unix" 18 | ProtoHTTP Proto = "http" 19 | ) 20 | 21 | type DBConfig struct { 22 | Alias string `json:"alias" yaml:"alias"` 23 | Driver dialect.DatabaseDriver `json:"driver" yaml:"driver"` 24 | DataSourceName string `json:"dataSourceName" yaml:"dataSourceName"` 25 | Proto Proto `json:"proto" yaml:"proto"` 26 | User string `json:"user" yaml:"user"` 27 | Passwd string `json:"passwd" yaml:"passwd"` 28 | Host string `json:"host" yaml:"host"` 29 | Port int `json:"port" yaml:"port"` 30 | Path string `json:"path" yaml:"path"` 31 | DBName string `json:"dbName" yaml:"dbName"` 32 | Params map[string]string `json:"params" yaml:"params"` 33 | SSHCfg *SSHConfig `json:"sshConfig" yaml:"sshConfig"` 34 | } 35 | 36 | func (c *DBConfig) Validate() error { 37 | if c.Driver == "" { 38 | return errors.New("required: connections[].driver") 39 | } 40 | 41 | switch c.Driver { 42 | case 43 | dialect.DatabaseDriverMySQL, 44 | 
dialect.DatabaseDriverMySQL8, 45 | dialect.DatabaseDriverMySQL57, 46 | dialect.DatabaseDriverMySQL56, 47 | dialect.DatabaseDriverPostgreSQL, 48 | dialect.DatabaseDriverVertica: 49 | if c.DataSourceName == "" && c.Proto == "" { 50 | return errors.New("required: connections[].dataSourceName or connections[].proto") 51 | } 52 | 53 | if c.DataSourceName == "" && c.Proto != "" { 54 | if c.User == "" { 55 | return errors.New("required: connections[].user") 56 | } 57 | switch c.Proto { 58 | case ProtoTCP, ProtoUDP, ProtoHTTP: 59 | if c.Host == "" { 60 | return errors.New("required: connections[].host") 61 | } 62 | case ProtoUnix: 63 | if c.Path == "" { 64 | return errors.New("required: connections[].path") 65 | } 66 | default: 67 | return errors.New("invalid: connections[].proto") 68 | } 69 | if c.SSHCfg != nil { 70 | return c.SSHCfg.Validate() 71 | } 72 | } 73 | case dialect.DatabaseDriverSQLite3: 74 | case dialect.DatabaseDriverH2: 75 | if c.DataSourceName == "" { 76 | return errors.New("required: connections[].dataSourceName") 77 | } 78 | case dialect.DatabaseDriverMssql: 79 | if c.DataSourceName == "" && c.Proto == "" { 80 | return errors.New("required: connections[].dataSourceName or connections[].proto") 81 | } 82 | if c.DataSourceName == "" && c.Proto != "" { 83 | if c.User == "" { 84 | return errors.New("required: connections[].user") 85 | } 86 | switch c.Proto { 87 | case ProtoTCP: 88 | if c.Host == "" { 89 | return errors.New("required: connections[].host") 90 | } 91 | case ProtoUDP, ProtoUnix, ProtoHTTP: 92 | default: 93 | return errors.New("invalid: connections[].proto") 94 | } 95 | } 96 | case dialect.DatabaseDriverOracle: 97 | if c.DataSourceName == "" && c.Proto == "" { 98 | return errors.New("required: connections[].dataSourceName or connections[].proto") 99 | } 100 | if c.DataSourceName == "" { 101 | if c.User == "" { 102 | return errors.New("required: connections[].user") 103 | } 104 | if c.Passwd == "" { 105 | return errors.New("required: 
connections[].Passwd") 106 | } 107 | if c.Host == "" { 108 | return errors.New("required: connections[].Host") 109 | } 110 | if c.Port <= 0 { 111 | return errors.New("required: connections[].Port") 112 | } 113 | if c.DBName == "" { 114 | return errors.New("required: connections[].DBName") 115 | } 116 | } 117 | case dialect.DatabaseDriverClickhouse: 118 | if c.DataSourceName == "" && c.Proto == "" { 119 | return errors.New("required: connections[].dataSourceName or connections[].proto") 120 | } 121 | 122 | if c.DataSourceName == "" && c.Proto != "" { 123 | if c.User == "" { 124 | return errors.New("required: connections[].user") 125 | } 126 | switch c.Proto { 127 | case ProtoTCP, ProtoHTTP: 128 | if c.Host == "" { 129 | return errors.New("required: connections[].host") 130 | } 131 | case ProtoUDP, ProtoUnix: 132 | default: 133 | return errors.New("invalid: connections[].proto") 134 | } 135 | if c.SSHCfg != nil { 136 | return c.SSHCfg.Validate() 137 | } 138 | } 139 | 140 | default: 141 | return errors.New("invalid: connections[].driver") 142 | } 143 | return nil 144 | } 145 | 146 | type SSHConfig struct { 147 | Host string `json:"host" yaml:"host"` 148 | Port int `json:"port" yaml:"port"` 149 | User string `json:"user" yaml:"user"` 150 | PassPhrase string `json:"passPhrase" yaml:"passPhrase"` 151 | PrivateKey string `json:"privateKey" yaml:"privateKey"` 152 | } 153 | 154 | func (s *SSHConfig) Validate() error { 155 | if s.Host == "" { 156 | return errors.New("required: connections[]sshConfig.host") 157 | } 158 | if s.User == "" { 159 | return errors.New("required: connections[].sshConfig.user") 160 | } 161 | if s.PrivateKey == "" { 162 | return errors.New("required: connections[].sshConfig.privateKey") 163 | } 164 | return nil 165 | } 166 | 167 | func (s *SSHConfig) Endpoint() string { 168 | return fmt.Sprintf("%s:%d", s.Host, s.Port) 169 | } 170 | 171 | func (s *SSHConfig) ClientConfig() (*ssh.ClientConfig, error) { 172 | buffer, err := os.ReadFile(s.PrivateKey) 173 
| if err != nil { 174 | return nil, fmt.Errorf("cannot read SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err) 175 | } 176 | 177 | var key ssh.Signer 178 | if s.PassPhrase != "" { 179 | key, err = ssh.ParsePrivateKeyWithPassphrase(buffer, []byte(s.PassPhrase)) 180 | if err != nil { 181 | return nil, fmt.Errorf("cannot parse SSH private key file with passphrase, PrivateKey=%s, %w", s.PrivateKey, err) 182 | } 183 | } else { 184 | key, err = ssh.ParsePrivateKey(buffer) 185 | if err != nil { 186 | return nil, fmt.Errorf("cannot parse SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err) 187 | } 188 | } 189 | 190 | sshConfig := &ssh.ClientConfig{ 191 | User: s.User, 192 | Auth: []ssh.AuthMethod{ssh.PublicKeys(key)}, 193 | HostKeyCallback: ssh.InsecureIgnoreHostKey(), 194 | } 195 | return sshConfig, nil 196 | } 197 | -------------------------------------------------------------------------------- /internal/database/database.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "database/sql" 7 | "errors" 8 | "fmt" 9 | "strings" 10 | 11 | "github.com/sqls-server/sqls/dialect" 12 | "github.com/sqls-server/sqls/parser/parseutil" 13 | ) 14 | 15 | var ( 16 | ErrNotImplementation error = errors.New("not implementation") 17 | ) 18 | 19 | const ( 20 | DefaultMaxIdleConns = 10 21 | DefaultMaxOpenConns = 5 22 | ) 23 | 24 | type DBRepository interface { 25 | Driver() dialect.DatabaseDriver 26 | CurrentDatabase(ctx context.Context) (string, error) 27 | Databases(ctx context.Context) ([]string, error) 28 | CurrentSchema(ctx context.Context) (string, error) 29 | Schemas(ctx context.Context) ([]string, error) 30 | SchemaTables(ctx context.Context) (map[string][]string, error) 31 | DescribeDatabaseTable(ctx context.Context) ([]*ColumnDesc, error) 32 | DescribeDatabaseTableBySchema(ctx context.Context, schemaName string) ([]*ColumnDesc, error) 33 | Exec(ctx context.Context, 
query string) (sql.Result, error) 34 | Query(ctx context.Context, query string) (*sql.Rows, error) 35 | DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) 36 | } 37 | 38 | type DBOption struct { 39 | MaxIdleConns int 40 | MaxOpenConns int 41 | } 42 | 43 | type ColumnBase struct { 44 | Schema string 45 | Table string 46 | Name string 47 | } 48 | 49 | type ColumnDesc struct { 50 | ColumnBase 51 | Type string 52 | Null string 53 | Key string 54 | Default sql.NullString 55 | Extra string 56 | } 57 | 58 | type ForeignKey [][2]*ColumnBase 59 | 60 | type fkItemDesc struct { 61 | fkID string 62 | schema string 63 | table string 64 | column string 65 | refTable string 66 | refColumn string 67 | } 68 | 69 | func (cd *ColumnDesc) OnelineDesc() string { 70 | items := []string{} 71 | if cd.Type != "" { 72 | items = append(items, "`"+cd.Type+"`") 73 | } 74 | if cd.Key == "YES" { 75 | items = append(items, "PRIMARY KEY") 76 | } else if cd.Key != "" && cd.Key != "NO" { 77 | items = append(items, cd.Key) 78 | } 79 | if cd.Extra != "" { 80 | items = append(items, cd.Extra) 81 | } 82 | return strings.Join(items, " ") 83 | } 84 | 85 | func ColumnDoc(tableName string, colDesc *ColumnDesc) string { 86 | buf := new(bytes.Buffer) 87 | fmt.Fprintf(buf, "`%s`.`%s` column", tableName, colDesc.Name) 88 | fmt.Fprintln(buf) 89 | fmt.Fprintln(buf) 90 | fmt.Fprintln(buf, colDesc.OnelineDesc()) 91 | return buf.String() 92 | } 93 | 94 | func Coalesce(str ...string) string { 95 | for _, s := range str { 96 | if s != "" { 97 | return s 98 | } 99 | } 100 | return "" 101 | } 102 | 103 | func TableDoc(tableName string, cols []*ColumnDesc) string { 104 | buf := new(bytes.Buffer) 105 | fmt.Fprintf(buf, "# `%s` table", tableName) 106 | fmt.Fprintln(buf) 107 | fmt.Fprintln(buf) 108 | fmt.Fprintln(buf) 109 | fmt.Fprintln(buf, "| Name   | Type   | Primary key   | Default   | Extra   |") 110 | fmt.Fprintln(buf, "| :--------------- | :--------------- | 
:---------------------- | :------------------ | :---------------- |") 111 | for _, col := range cols { 112 | fmt.Fprintf(buf, "| `%s` | `%s` | `%s` | `%s` | %s |", col.Name, col.Type, col.Key, Coalesce(col.Default.String, "-"), col.Extra) 113 | fmt.Fprintln(buf) 114 | } 115 | return buf.String() 116 | } 117 | 118 | func SubqueryDoc(name string, views []*parseutil.SubQueryView, dbCache *DBCache) string { 119 | buf := new(bytes.Buffer) 120 | fmt.Fprintf(buf, "%s subquery", name) 121 | fmt.Fprintln(buf) 122 | fmt.Fprintln(buf) 123 | for _, view := range views { 124 | for _, colmun := range view.SubQueryColumns { 125 | if colmun.ColumnName == "*" { 126 | tableCols, ok := dbCache.ColumnDescs(colmun.ParentTable.Name) 127 | if !ok { 128 | continue 129 | } 130 | for _, tableCol := range tableCols { 131 | fmt.Fprintf(buf, "- %s(%s.%s): %s", tableCol.Name, colmun.ParentTable.Name, tableCol.Name, tableCol.OnelineDesc()) 132 | fmt.Fprintln(buf) 133 | } 134 | } else { 135 | columnDesc, ok := dbCache.Column(colmun.ParentTable.Name, colmun.ColumnName) 136 | if !ok { 137 | continue 138 | } 139 | fmt.Fprintf(buf, "- %s(%s.%s): %s", colmun.DisplayName(), colmun.ParentTable.Name, colmun.ColumnName, columnDesc.OnelineDesc()) 140 | fmt.Fprintln(buf) 141 | 142 | } 143 | } 144 | } 145 | return buf.String() 146 | } 147 | 148 | func SubqueryColumnDoc(identName string, views []*parseutil.SubQueryView, dbCache *DBCache) string { 149 | buf := new(bytes.Buffer) 150 | fmt.Fprintf(buf, "%s subquery column", identName) 151 | fmt.Fprintln(buf) 152 | fmt.Fprintln(buf) 153 | for _, view := range views { 154 | for _, colmun := range view.SubQueryColumns { 155 | if colmun.ColumnName == "*" { 156 | tableCols, ok := dbCache.ColumnDescs(colmun.ParentTable.Name) 157 | if !ok { 158 | continue 159 | } 160 | for _, tableCol := range tableCols { 161 | if identName == tableCol.Name { 162 | fmt.Fprintf(buf, "- %s(%s.%s): %s", identName, colmun.ParentTable.Name, tableCol.Name, tableCol.OnelineDesc()) 163 | 
fmt.Fprintln(buf) 164 | continue 165 | } 166 | } 167 | } else { 168 | if identName != colmun.ColumnName && identName != colmun.AliasName { 169 | continue 170 | } 171 | columnDesc, ok := dbCache.Column(colmun.ParentTable.Name, colmun.ColumnName) 172 | if !ok { 173 | continue 174 | } 175 | fmt.Fprintf(buf, "- %s(%s.%s): %s", identName, colmun.ParentTable.Name, colmun.ColumnName, columnDesc.OnelineDesc()) 176 | fmt.Fprintln(buf) 177 | } 178 | } 179 | } 180 | return buf.String() 181 | } 182 | 183 | func parseForeignKeys(rows *sql.Rows, schemaName string) ([]*ForeignKey, error) { 184 | var retVal []*ForeignKey 185 | var prevFk string 186 | var cur *ForeignKey 187 | for rows.Next() { 188 | var fkItem fkItemDesc 189 | err := rows.Scan( 190 | &fkItem.fkID, 191 | &fkItem.table, 192 | &fkItem.column, 193 | &fkItem.refTable, 194 | &fkItem.refColumn, 195 | ) 196 | if err != nil { 197 | return nil, err 198 | } 199 | var l, r ColumnBase 200 | l.Schema = schemaName 201 | l.Table = fkItem.table 202 | l.Name = fkItem.column 203 | r.Schema = l.Schema 204 | r.Table = fkItem.refTable 205 | r.Name = fkItem.refColumn 206 | if fkItem.fkID != prevFk { 207 | if cur != nil { 208 | retVal = append(retVal, cur) 209 | } 210 | cur = new(ForeignKey) 211 | } 212 | *cur = append(*cur, [2]*ColumnBase{&l, &r}) 213 | prevFk = fkItem.fkID 214 | } 215 | 216 | if cur != nil { 217 | retVal = append(retVal, cur) 218 | } 219 | return retVal, nil 220 | } 221 | -------------------------------------------------------------------------------- /internal/database/driver.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | 7 | "github.com/sqls-server/sqls/dialect" 8 | "golang.org/x/crypto/ssh" 9 | ) 10 | 11 | var driverOpeners = make(map[dialect.DatabaseDriver]Opener) 12 | var driverFactories = make(map[dialect.DatabaseDriver]Factory) 13 | 14 | type Opener func(*DBConfig) (*DBConnection, error) 15 | type Factory 
func(*sql.DB) DBRepository 16 | 17 | type DBConnection struct { 18 | Conn *sql.DB 19 | SSHConn *ssh.Client 20 | Driver dialect.DatabaseDriver 21 | } 22 | 23 | func (db *DBConnection) Close() error { 24 | if db == nil { 25 | return nil 26 | } 27 | if err := db.Conn.Close(); err != nil { 28 | return err 29 | } 30 | if db.SSHConn != nil { 31 | if err := db.SSHConn.Close(); err != nil { 32 | return err 33 | } 34 | } 35 | return nil 36 | } 37 | 38 | func RegisterOpen(name dialect.DatabaseDriver, opener Opener) { 39 | if _, ok := driverOpeners[name]; ok { 40 | panic(fmt.Sprintf("driver open %s method is already registered", name)) 41 | } 42 | driverOpeners[name] = opener 43 | } 44 | 45 | func RegisterFactory(name dialect.DatabaseDriver, factory Factory) { 46 | if _, ok := driverFactories[name]; ok { 47 | panic(fmt.Sprintf("driver factory %s already registered", name)) 48 | } 49 | driverFactories[name] = factory 50 | } 51 | 52 | func Registered(name dialect.DatabaseDriver) bool { 53 | _, ok1 := driverOpeners[name] 54 | _, ok2 := driverFactories[name] 55 | return ok1 && ok2 56 | } 57 | 58 | func Open(cfg *DBConfig) (*DBConnection, error) { 59 | OpenFn, ok := driverOpeners[cfg.Driver] 60 | if !ok { 61 | return nil, fmt.Errorf("driver not found, %s", cfg.Driver) 62 | } 63 | return OpenFn(cfg) 64 | } 65 | 66 | func CreateRepository(driver dialect.DatabaseDriver, db *sql.DB) (DBRepository, error) { 67 | FactoryFn, ok := driverFactories[driver] 68 | if !ok { 69 | return nil, fmt.Errorf("driver not found, %s", driver) 70 | } 71 | return FactoryFn(db), nil 72 | } 73 | -------------------------------------------------------------------------------- /internal/database/h2.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "log" 8 | 9 | _ "github.com/CodinGame/h2go" 10 | "github.com/sqls-server/sqls/dialect" 11 | ) 12 | 13 | func init() { 14 | RegisterOpen("h2", 
h2Open) 15 | RegisterFactory("h2", NewH2DBRepository) 16 | } 17 | 18 | func h2Open(dbConnCfg *DBConfig) (*DBConnection, error) { 19 | var ( 20 | conn *sql.DB 21 | ) 22 | cfg, err := genH2Config(dbConnCfg) 23 | if err != nil { 24 | return nil, err 25 | } 26 | 27 | if dbConnCfg.SSHCfg != nil { 28 | return nil, fmt.Errorf("connect via SSH is not supported") 29 | } 30 | dbConn, err := sql.Open("h2", cfg) 31 | if err != nil { 32 | return nil, err 33 | } 34 | conn = dbConn 35 | 36 | return &DBConnection{ 37 | Conn: conn, 38 | Driver: dbConnCfg.Driver, 39 | }, nil 40 | } 41 | 42 | func genH2Config(connCfg *DBConfig) (string, error) { 43 | if connCfg.DataSourceName != "" { 44 | return connCfg.DataSourceName, nil 45 | } 46 | 47 | return "", fmt.Errorf("Only DataSourceName is supported") 48 | } 49 | 50 | type H2DBRepository struct { 51 | Conn *sql.DB 52 | driver dialect.DatabaseDriver 53 | } 54 | 55 | func NewH2DBRepository(conn *sql.DB) DBRepository { 56 | return &H2DBRepository{Conn: conn} 57 | } 58 | 59 | func (db *H2DBRepository) Driver() dialect.DatabaseDriver { 60 | return db.driver 61 | } 62 | 63 | func (db *H2DBRepository) CurrentDatabase(ctx context.Context) (string, error) { 64 | return "", nil 65 | } 66 | 67 | func (db *H2DBRepository) Databases(ctx context.Context) ([]string, error) { 68 | return []string{}, nil 69 | } 70 | 71 | func (db *H2DBRepository) CurrentSchema(ctx context.Context) (string, error) { 72 | return "PUBLIC", nil 73 | } 74 | 75 | func (db *H2DBRepository) Schemas(ctx context.Context) ([]string, error) { 76 | rows, err := db.Conn.QueryContext( 77 | ctx, 78 | ` 79 | SELECT schema_name FROM information_schema.schemata 80 | `) 81 | if err != nil { 82 | log.Fatal(err) 83 | } 84 | defer rows.Close() 85 | schemas := []string{} 86 | for rows.Next() { 87 | var schema string 88 | if err := rows.Scan(&schema); err != nil { 89 | return nil, err 90 | } 91 | schemas = append(schemas, schema) 92 | } 93 | return schemas, nil 94 | } 95 | 96 | func (db 
*H2DBRepository) SchemaTables(ctx context.Context) (map[string][]string, error) { 97 | rows, err := db.Conn.QueryContext( 98 | ctx, 99 | ` 100 | SELECT 101 | table_schema, 102 | table_name 103 | FROM 104 | information_schema.tables 105 | ORDER BY 106 | table_schema, 107 | table_name 108 | `) 109 | if err != nil { 110 | return nil, err 111 | } 112 | defer rows.Close() 113 | databaseTables := map[string][]string{} 114 | for rows.Next() { 115 | var schema, table string 116 | if err := rows.Scan(&schema, &table); err != nil { 117 | return nil, err 118 | } 119 | 120 | if arr, ok := databaseTables[schema]; ok { 121 | databaseTables[schema] = append(arr, table) 122 | } else { 123 | databaseTables[schema] = []string{table} 124 | } 125 | } 126 | return databaseTables, nil 127 | } 128 | 129 | func (db *H2DBRepository) Tables(ctx context.Context) ([]string, error) { 130 | 131 | rows, err := db.Conn.QueryContext( 132 | ctx, 133 | ` 134 | SELECT 135 | table_name 136 | FROM 137 | information_schema.tables 138 | WHERE 139 | table_schema NOT IN ('INFORMATION_SCHEMA') 140 | ORDER BY 141 | table_name 142 | `) 143 | if err != nil { 144 | log.Fatal(err) 145 | } 146 | defer rows.Close() 147 | tables := []string{} 148 | for rows.Next() { 149 | var table string 150 | if err := rows.Scan(&table); err != nil { 151 | return nil, err 152 | } 153 | tables = append(tables, table) 154 | } 155 | return tables, nil 156 | } 157 | 158 | func (db *H2DBRepository) DescribeDatabaseTable(ctx context.Context) ([]*ColumnDesc, error) { 159 | rows, err := db.Conn.QueryContext( 160 | ctx, 161 | ` 162 | SELECT 163 | c.table_schema, 164 | c.table_name, 165 | c.column_name, 166 | c.type_name, 167 | c.is_nullable, 168 | CASE tc.constraint_type 169 | WHEN 'PRIMARY KEY' THEN 'YES' 170 | ELSE 'NO' 171 | END, 172 | c.column_default, 173 | '' 174 | FROM 175 | information_schema.columns c 176 | LEFT JOIN 177 | information_schema.constraints tc 178 | ON c.table_schema = tc.table_schema 179 | AND c.table_name = 
tc.table_name 180 | AND REGEXP_LIKE(tc.column_list, '(^|,)' || c.column_name || '(,|$)', 'i') 181 | ORDER BY 182 | c.table_name, 183 | c.ordinal_position 184 | `) 185 | if err != nil { 186 | log.Fatal(err) 187 | } 188 | defer rows.Close() 189 | tableInfos := []*ColumnDesc{} 190 | for rows.Next() { 191 | var tableInfo ColumnDesc 192 | err := rows.Scan( 193 | &tableInfo.Schema, 194 | &tableInfo.Table, 195 | &tableInfo.Name, 196 | &tableInfo.Type, 197 | &tableInfo.Null, 198 | &tableInfo.Key, 199 | &tableInfo.Default, 200 | &tableInfo.Extra, 201 | ) 202 | if err != nil { 203 | return nil, err 204 | } 205 | tableInfos = append(tableInfos, &tableInfo) 206 | } 207 | return tableInfos, nil 208 | } 209 | 210 | func (db *H2DBRepository) DescribeDatabaseTableBySchema(ctx context.Context, schemaName string) ([]*ColumnDesc, error) { 211 | // h2go doesn't support NamedValue yet 212 | rows, err := db.Conn.QueryContext( 213 | ctx, 214 | fmt.Sprintf(` 215 | SELECT 216 | c.table_schema, 217 | c.table_name, 218 | c.column_name, 219 | c.type_name, 220 | c.is_nullable, 221 | CASE tc.constraint_type 222 | WHEN 'PRIMARY KEY' THEN 'YES' 223 | ELSE 'NO' 224 | END, 225 | c.column_default, 226 | '' 227 | FROM 228 | information_schema.columns c 229 | LEFT JOIN 230 | information_schema.constraints tc 231 | ON c.table_schema = tc.table_schema 232 | AND c.table_name = tc.table_name 233 | AND REGEXP_LIKE(tc.column_list, '(^|,)' || c.column_name || '(,|$)', 'i') 234 | WHERE 235 | c.table_schema = '%s' 236 | ORDER BY 237 | c.table_name, 238 | c.ordinal_position 239 | `, schemaName)) 240 | if err != nil { 241 | log.Fatal(err) 242 | } 243 | defer rows.Close() 244 | tableInfos := []*ColumnDesc{} 245 | for rows.Next() { 246 | var tableInfo ColumnDesc 247 | err := rows.Scan( 248 | &tableInfo.Schema, 249 | &tableInfo.Table, 250 | &tableInfo.Name, 251 | &tableInfo.Type, 252 | &tableInfo.Null, 253 | &tableInfo.Key, 254 | &tableInfo.Default, 255 | &tableInfo.Extra, 256 | ) 257 | if err != nil { 258 | return 
nil, err 259 | } 260 | tableInfos = append(tableInfos, &tableInfo) 261 | } 262 | return tableInfos, nil 263 | } 264 | 265 | func (db *H2DBRepository) Exec(ctx context.Context, query string) (sql.Result, error) { 266 | return db.Conn.ExecContext(ctx, query) 267 | } 268 | 269 | func (db *H2DBRepository) Query(ctx context.Context, query string) (*sql.Rows, error) { 270 | return db.Conn.QueryContext(ctx, query) 271 | } 272 | 273 | func (db *H2DBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { 274 | return nil, fmt.Errorf("describe foreign keys is not supported") 275 | } 276 | -------------------------------------------------------------------------------- /internal/database/mssql_test.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "testing" 5 | 6 | _ "github.com/denisenkom/go-mssqldb" 7 | ) 8 | 9 | func Test_genMssqlConfig(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | connCfg *DBConfig 13 | want string 14 | wantErr bool 15 | }{ 16 | { 17 | name: "", 18 | connCfg: &DBConfig{ 19 | Alias: "", 20 | Driver: "mssql", 21 | DataSourceName: "", 22 | Proto: "tcp", 23 | User: "sa", 24 | Passwd: "mysecretpassword1234", 25 | Host: "127.0.0.1", 26 | Port: 11433, 27 | Path: "", 28 | DBName: "dvdrental", 29 | Params: map[string]string{ 30 | "ApplicationIntent": "ReadOnly", 31 | }, 32 | }, 33 | want: "ApplicationIntent=ReadOnly;database=dvdrental;password=mysecretpassword1234;port=11433;server=127.0.0.1;user=sa", 34 | wantErr: false, 35 | }, 36 | } 37 | for _, tt := range tests { 38 | t.Run(tt.name, func(t *testing.T) { 39 | got, err := genMssqlConfig(tt.connCfg) 40 | if (err != nil) != tt.wantErr { 41 | t.Errorf("genMssqlConfig() error = %v, wantErr %v", err, tt.wantErr) 42 | return 43 | } 44 | if got != tt.want { 45 | t.Errorf("got %q, want %q", got, tt.want) 46 | } 47 | }) 48 | } 49 | } 50 | 
-------------------------------------------------------------------------------- /internal/database/oracle.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "log" 7 | "strconv" 8 | 9 | _ "github.com/godror/godror" 10 | "github.com/sqls-server/sqls/dialect" 11 | ) 12 | 13 | func init() { 14 | RegisterOpen("oracle", oracleOpen) 15 | RegisterFactory("oracle", NewOracleDBRepository) 16 | } 17 | 18 | func oracleOpen(dbConnCfg *DBConfig) (*DBConnection, error) { 19 | var ( 20 | conn *sql.DB 21 | ) 22 | DSName, err := genOracleConfig(dbConnCfg) 23 | if err != nil { 24 | return nil, err 25 | } 26 | 27 | conn, err = sql.Open("godror", DSName) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | conn.SetMaxIdleConns(DefaultMaxIdleConns) 33 | conn.SetMaxOpenConns(DefaultMaxOpenConns) 34 | 35 | return &DBConnection{ 36 | Conn: conn, 37 | Driver: dialect.DatabaseDriverOracle, 38 | }, nil 39 | } 40 | 41 | func genOracleConfig(connCfg *DBConfig) (string, error) { 42 | if connCfg.DataSourceName != "" { 43 | return connCfg.DataSourceName, nil 44 | } 45 | 46 | host, port := connCfg.Host, connCfg.Port 47 | if host == "" { 48 | host = "127.0.0.1" 49 | } 50 | if port == 0 { 51 | port = 1521 52 | } 53 | DSName := connCfg.User + "/" + connCfg.Passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName 54 | return DSName, nil 55 | } 56 | 57 | type OracleDBRepository struct { 58 | Conn *sql.DB 59 | } 60 | 61 | func NewOracleDBRepository(conn *sql.DB) DBRepository { 62 | return &OracleDBRepository{Conn: conn} 63 | } 64 | 65 | func (db *OracleDBRepository) Driver() dialect.DatabaseDriver { 66 | return dialect.DatabaseDriverOracle 67 | } 68 | 69 | func (db *OracleDBRepository) CurrentDatabase(ctx context.Context) (string, error) { 70 | row := db.Conn.QueryRowContext(ctx, "SELECT SYS_CONTEXT('USERENV','CURRENT_SCHEMA') FROM DUAL") 71 | var database string 72 | if err := 
row.Scan(&database); err != nil { 73 | return "", err 74 | } 75 | return database, nil 76 | } 77 | 78 | func (db *OracleDBRepository) Databases(ctx context.Context) ([]string, error) { 79 | // one DB per connection for Oracle 80 | rows, err := db.Conn.QueryContext(ctx, "SELECT USERNAME FROM SYS.ALL_USERS ORDER BY USERNAME") 81 | if err != nil { 82 | return nil, err 83 | } 84 | defer rows.Close() 85 | databases := []string{} 86 | for rows.Next() { 87 | var database string 88 | if err := rows.Scan(&database); err != nil { 89 | return nil, err 90 | } 91 | databases = append(databases, database) 92 | } 93 | return databases, nil 94 | } 95 | 96 | func (db *OracleDBRepository) CurrentSchema(ctx context.Context) (string, error) { 97 | return db.CurrentDatabase(ctx) 98 | } 99 | 100 | func (db *OracleDBRepository) Schemas(ctx context.Context) ([]string, error) { 101 | return db.Databases(ctx) 102 | } 103 | 104 | func (db *OracleDBRepository) SchemaTables(ctx context.Context) (map[string][]string, error) { 105 | rows, err := db.Conn.QueryContext( 106 | ctx, 107 | ` 108 | SELECT OWNER, TABLE_NAME 109 | FROM SYS.ALL_TABLES 110 | ORDER BY OWNER, TABLE_NAME 111 | `) 112 | if err != nil { 113 | return nil, err 114 | } 115 | defer rows.Close() 116 | databaseTables := map[string][]string{} 117 | for rows.Next() { 118 | var schema, table string 119 | if err := rows.Scan(&schema, &table); err != nil { 120 | return nil, err 121 | } 122 | 123 | if arr, ok := databaseTables[schema]; ok { 124 | databaseTables[schema] = append(arr, table) 125 | } else { 126 | databaseTables[schema] = []string{table} 127 | } 128 | } 129 | return databaseTables, nil 130 | } 131 | 132 | func (db *OracleDBRepository) Tables(ctx context.Context) ([]string, error) { 133 | rows, err := db.Conn.QueryContext(ctx, "SELECT TABLE_NAME FROM USER_TABLES") 134 | if err != nil { 135 | return nil, err 136 | } 137 | defer rows.Close() 138 | tables := []string{} 139 | for rows.Next() { 140 | var table string 141 | if err := 
rows.Scan(&table); err != nil { 142 | return nil, err 143 | } 144 | tables = append(tables, table) 145 | } 146 | return tables, nil 147 | } 148 | 149 | func (db *OracleDBRepository) DescribeDatabaseTable(ctx context.Context) ([]*ColumnDesc, error) { 150 | rows, err := db.Conn.QueryContext( 151 | ctx, 152 | ` 153 | SELECT 154 | OWNER, 155 | TABLE_NAME, 156 | COLUMN_NAME, 157 | DATA_TYPE, 158 | NULLABLE, 159 | '', 160 | DATA_DEFAULT, 161 | '' 162 | FROM SYS.ALL_TAB_COLUMNS 163 | `) 164 | if err != nil { 165 | return nil, err 166 | } 167 | defer rows.Close() 168 | tableInfos := []*ColumnDesc{} 169 | for rows.Next() { 170 | var tableInfo ColumnDesc 171 | err := rows.Scan( 172 | &tableInfo.Schema, 173 | &tableInfo.Table, 174 | &tableInfo.Name, 175 | &tableInfo.Type, 176 | &tableInfo.Null, 177 | &tableInfo.Key, 178 | &tableInfo.Default, 179 | &tableInfo.Extra, 180 | ) 181 | if err != nil { 182 | return nil, err 183 | } 184 | tableInfos = append(tableInfos, &tableInfo) 185 | } 186 | return tableInfos, nil 187 | } 188 | 189 | func (db *OracleDBRepository) DescribeDatabaseTableBySchema(ctx context.Context, schemaName string) ([]*ColumnDesc, error) { 190 | rows, err := db.Conn.QueryContext( 191 | ctx, 192 | ` 193 | SELECT 194 | OWNER, 195 | TABLE_NAME, 196 | COLUMN_NAME, 197 | DATA_TYPE, 198 | CASE NULLABLE 199 | WHEN 'Y' THEN 'YES' 200 | ELSE 'NO' 201 | END, 202 | '1', 203 | DATA_DEFAULT, 204 | '1' 205 | FROM SYS.ALL_TAB_COLUMNS 206 | WHERE OWNER = :1 207 | `, schemaName) 208 | if err != nil { 209 | log.Println("schema", schemaName, err.Error()) 210 | return nil, err 211 | } 212 | tableInfos := []*ColumnDesc{} 213 | for rows.Next() { 214 | var tableInfo ColumnDesc 215 | err := rows.Scan( 216 | &tableInfo.Schema, 217 | &tableInfo.Table, 218 | &tableInfo.Name, 219 | &tableInfo.Type, 220 | &tableInfo.Null, 221 | &tableInfo.Key, 222 | &tableInfo.Default, 223 | &tableInfo.Extra, 224 | ) 225 | if err != nil { 226 | return nil, err 227 | } 228 | tableInfos = append(tableInfos, 
&tableInfo) 229 | } 230 | defer rows.Close() 231 | return tableInfos, nil 232 | } 233 | 234 | func (db *OracleDBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { 235 | rows, err := db.Conn.QueryContext( 236 | ctx, 237 | ` 238 | SELECT a.CONSTRAINT_NAME, 239 | a.TABLE_NAME, 240 | a.COLUMN_NAME, 241 | b.TABLE_NAME, 242 | b.COLUMN_NAME 243 | FROM ALL_CONS_COLUMNS a 244 | JOIN ALL_CONSTRAINTS c ON a.OWNER = c.OWNER 245 | AND a.CONSTRAINT_NAME = c.CONSTRAINT_NAME 246 | JOIN ALL_CONSTRAINTS c_pk ON c.R_OWNER = c_pk.OWNER 247 | AND c.R_CONSTRAINT_NAME = c_pk.CONSTRAINT_NAME 248 | JOIN ALL_CONS_COLUMNS b ON b.CONSTRAINT_NAME = c_pk.CONSTRAINT_NAME 249 | AND b.POSITION = a.POSITION 250 | WHERE c.constraint_type = 'R' 251 | AND a.OWNER = :1 252 | ORDER BY a.CONSTRAINT_NAME, a.POSITION 253 | `, schemaName) 254 | if err != nil { 255 | log.Fatal(err) 256 | } 257 | defer func() { _ = rows.Close() }() 258 | return parseForeignKeys(rows, schemaName) 259 | } 260 | 261 | func (db *OracleDBRepository) Exec(ctx context.Context, query string) (sql.Result, error) { 262 | return db.Conn.ExecContext(ctx, query) 263 | } 264 | 265 | func (db *OracleDBRepository) Query(ctx context.Context, query string) (*sql.Rows, error) { 266 | return db.Conn.QueryContext(ctx, query) 267 | } 268 | -------------------------------------------------------------------------------- /internal/database/postgresql_test.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "testing" 5 | 6 | _ "github.com/jackc/pgx/v4/stdlib" 7 | ) 8 | 9 | func Test_genPostgresConfig(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | connCfg *DBConfig 13 | want string 14 | wantErr bool 15 | }{ 16 | { 17 | name: "", 18 | connCfg: &DBConfig{ 19 | Alias: "", 20 | Driver: "postgresql", 21 | DataSourceName: "", 22 | Proto: "tcp", 23 | User: "postgres", 24 | Passwd: "mysecretpassword1234", 25 | Host: 
"127.0.0.1", 26 | Port: 15432, 27 | Path: "", 28 | DBName: "dvdrental", 29 | Params: map[string]string{ 30 | "sslmode": "disable", 31 | }, 32 | }, 33 | want: "dbname=dvdrental host=127.0.0.1 password=mysecretpassword1234 port=15432 sslmode=disable user=postgres", 34 | wantErr: false, 35 | }, 36 | } 37 | for _, tt := range tests { 38 | t.Run(tt.name, func(t *testing.T) { 39 | got, err := genPostgresConfig(tt.connCfg) 40 | if (err != nil) != tt.wantErr { 41 | t.Errorf("genPostgresConfig() error = %v, wantErr %v", err, tt.wantErr) 42 | return 43 | } 44 | if got != tt.want { 45 | t.Errorf("got %q, want %q", got, tt.want) 46 | } 47 | }) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /internal/database/query_type_test.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import "testing" 4 | 5 | func TestQueryExecType(t *testing.T) { 6 | type args struct { 7 | prefix string 8 | sqlstr string 9 | } 10 | tests := []struct { 11 | name string 12 | prefix string 13 | sqlstr string 14 | wantPrefix string 15 | wantExecType bool 16 | }{ 17 | { 18 | name: "select", 19 | prefix: "select * from city", 20 | sqlstr: "", 21 | wantPrefix: "SELECT", 22 | wantExecType: true, 23 | }, 24 | { 25 | name: "select with space", 26 | prefix: " select * from city", 27 | sqlstr: "", 28 | wantPrefix: "SELECT", 29 | wantExecType: true, 30 | }, 31 | { 32 | name: "start linebreak", 33 | prefix: "\nselect * from city", 34 | sqlstr: "", 35 | wantPrefix: "SELECT", 36 | wantExecType: true, 37 | }, 38 | { 39 | name: "with linebreak", 40 | prefix: "select\n* from city", 41 | sqlstr: "", 42 | wantPrefix: "SELECT", 43 | wantExecType: true, 44 | }, 45 | { 46 | name: "start tab", 47 | prefix: "\tselect * from city", 48 | sqlstr: "", 49 | wantPrefix: "SELECT", 50 | wantExecType: true, 51 | }, 52 | { 53 | name: "with tab", 54 | prefix: "select\t* from city", 55 | sqlstr: "", 56 | wantPrefix: "SELECT", 
57 | wantExecType: true, 58 | }, 59 | { 60 | name: "explain", 61 | prefix: "explain select * from city", 62 | sqlstr: "", 63 | wantPrefix: "EXPLAIN", 64 | wantExecType: true, 65 | }, 66 | { 67 | name: "insert", 68 | prefix: "insert into city values (8181, 'Kabul', 'AFG', 'Kabol', 1780000);", 69 | sqlstr: "", 70 | wantPrefix: "INSERT", 71 | wantExecType: false, 72 | }, 73 | { 74 | name: "delete", 75 | prefix: "delete from city where id = 8181;", 76 | sqlstr: "", 77 | wantPrefix: "DELETE", 78 | wantExecType: false, 79 | }, 80 | } 81 | for _, tt := range tests { 82 | t.Run(tt.name, func(t *testing.T) { 83 | prefix, execType := QueryExecType(tt.prefix, tt.sqlstr) 84 | if prefix != tt.wantPrefix { 85 | t.Errorf("QueryExecType() got = %v, want %v", prefix, tt.wantPrefix) 86 | } 87 | if execType != tt.wantExecType { 88 | t.Errorf("QueryExecType() got1 = %v, want %v", execType, tt.wantExecType) 89 | } 90 | }) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /internal/database/scan_row.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | "fmt" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | func Columns(rows *sql.Rows) ([]string, error) { 12 | var cols []string 13 | var err error 14 | 15 | cols, err = rows.Columns() 16 | if err != nil { 17 | return nil, fmt.Errorf("cannot get query columns, %w", err) 18 | } 19 | 20 | for i, c := range cols { 21 | if strings.TrimSpace(c) == "" { 22 | cols[i] = fmt.Sprintf("col%d", i) 23 | } 24 | } 25 | 26 | return cols, nil 27 | } 28 | 29 | func ScanRows(rows *sql.Rows, columnLength int) ([][]string, error) { 30 | stringRows := [][]string{} 31 | for rows.Next() { 32 | // scan to []interface{} 33 | rowBuffer := make([]interface{}, columnLength) 34 | for i := range rowBuffer { 35 | rowBuffer[i] = new(interface{}) 36 | } 37 | if err := rows.Scan(rowBuffer...); err != nil { 38 | return nil, 
err 39 | } 40 | 41 | stringRow := make([]string, columnLength) 42 | for i, buf := range rowBuffer { 43 | val, err := sqlValToString(buf) 44 | if err != nil { 45 | return nil, err 46 | } 47 | stringRow[i] = val 48 | } 49 | stringRows = append(stringRows, stringRow) 50 | } 51 | return stringRows, nil 52 | } 53 | 54 | func sqlValToString(pointer interface{}) (string, error) { 55 | res := "" 56 | if pointer == nil { 57 | return res, nil 58 | } 59 | 60 | val := *pointer.(*interface{}) 61 | switch v := (val).(type) { 62 | case []byte: 63 | res = string(v) 64 | case string: 65 | res = v 66 | case time.Time: 67 | res = v.Format(time.RFC3339Nano) 68 | case fmt.Stringer: 69 | res = v.String() 70 | case map[string]interface{}: 71 | buf, err := json.Marshal(v) 72 | if err != nil { 73 | return "", err 74 | } 75 | res = string(buf) 76 | case []interface{}: 77 | buf, err := json.Marshal(v) 78 | if err != nil { 79 | return "", err 80 | } 81 | res = string(buf) 82 | default: 83 | res = fmt.Sprintf("%v", v) 84 | } 85 | 86 | return res, nil 87 | } 88 | -------------------------------------------------------------------------------- /internal/database/sqlite3.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "log" 8 | 9 | _ "github.com/mattn/go-sqlite3" 10 | "github.com/sqls-server/sqls/dialect" 11 | ) 12 | 13 | func init() { 14 | RegisterOpen(dialect.DatabaseDriverSQLite3, sqlite3Open) 15 | RegisterFactory(dialect.DatabaseDriverSQLite3, NewSQLite3DBRepository) 16 | } 17 | 18 | func sqlite3Open(connCfg *DBConfig) (*DBConnection, error) { 19 | conn, err := sql.Open("sqlite3", connCfg.DataSourceName) 20 | if err != nil { 21 | return nil, err 22 | } 23 | conn.SetMaxIdleConns(DefaultMaxIdleConns) 24 | conn.SetMaxOpenConns(DefaultMaxOpenConns) 25 | return &DBConnection{ 26 | Conn: conn, 27 | }, nil 28 | } 29 | 30 | type SQLite3DBRepository struct { 31 | Conn *sql.DB 32 | 
} 33 | 34 | func NewSQLite3DBRepository(conn *sql.DB) DBRepository { 35 | return &SQLite3DBRepository{Conn: conn} 36 | } 37 | 38 | func (db *SQLite3DBRepository) Driver() dialect.DatabaseDriver { 39 | return dialect.DatabaseDriverSQLite3 40 | } 41 | 42 | func (db *SQLite3DBRepository) CurrentDatabase(ctx context.Context) (string, error) { 43 | return "", nil 44 | } 45 | 46 | func (db *SQLite3DBRepository) Databases(ctx context.Context) ([]string, error) { 47 | return []string{}, nil 48 | } 49 | 50 | func (db *SQLite3DBRepository) CurrentSchema(ctx context.Context) (string, error) { 51 | return db.CurrentDatabase(ctx) 52 | } 53 | 54 | func (db *SQLite3DBRepository) Schemas(ctx context.Context) ([]string, error) { 55 | return db.Databases(ctx) 56 | } 57 | 58 | func (db *SQLite3DBRepository) SchemaTables(ctx context.Context) (map[string][]string, error) { 59 | tables, err := db.Tables(ctx) 60 | if err != nil { 61 | return nil, err 62 | } 63 | return map[string][]string{"": tables}, nil 64 | } 65 | 66 | func (db *SQLite3DBRepository) Tables(ctx context.Context) ([]string, error) { 67 | rows, err := db.Conn.QueryContext(ctx, ` 68 | SELECT 69 | name 70 | FROM 71 | sqlite_master 72 | WHERE 73 | type = 'table' 74 | ORDER BY 75 | name 76 | `) 77 | if err != nil { 78 | log.Fatal(err) 79 | } 80 | defer rows.Close() 81 | tables := []string{} 82 | for rows.Next() { 83 | var table string 84 | if err := rows.Scan(&table); err != nil { 85 | return nil, err 86 | } 87 | tables = append(tables, table) 88 | } 89 | return tables, nil 90 | } 91 | 92 | func (db *SQLite3DBRepository) describeTable(ctx context.Context, tableName string) ([]*ColumnDesc, error) { 93 | rows, err := db.Conn.QueryContext(ctx, fmt.Sprintf("PRAGMA table_info(%s);", tableName)) 94 | if err != nil { 95 | log.Fatal(err) 96 | } 97 | defer rows.Close() 98 | tableInfos := []*ColumnDesc{} 99 | for rows.Next() { 100 | var id int 101 | var nonnull int 102 | var tableInfo ColumnDesc 103 | err := rows.Scan( 104 | &id, 105 | 
&tableInfo.Name, 106 | &tableInfo.Type, 107 | &nonnull, 108 | &tableInfo.Default, 109 | &tableInfo.Key, 110 | ) 111 | if err != nil { 112 | return nil, err 113 | } 114 | tableInfo.Table = tableName 115 | if nonnull != 0 { 116 | tableInfo.Null = "NO" 117 | } else { 118 | tableInfo.Null = "YES" 119 | } 120 | tableInfos = append(tableInfos, &tableInfo) 121 | } 122 | return tableInfos, nil 123 | } 124 | 125 | func (db *SQLite3DBRepository) DescribeDatabaseTable(ctx context.Context) ([]*ColumnDesc, error) { 126 | tables, err := db.Tables(ctx) 127 | if err != nil { 128 | return nil, err 129 | } 130 | all := []*ColumnDesc{} 131 | for _, table := range tables { 132 | descs, err := db.describeTable(ctx, table) 133 | if err != nil { 134 | return nil, err 135 | } 136 | all = append(all, descs...) 137 | } 138 | return all, nil 139 | } 140 | 141 | func (db *SQLite3DBRepository) DescribeDatabaseTableBySchema(ctx context.Context, _ string) ([]*ColumnDesc, error) { 142 | return db.DescribeDatabaseTable(ctx) 143 | } 144 | 145 | func (db *SQLite3DBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { 146 | rows, err := db.Conn.QueryContext( 147 | ctx, 148 | ` 149 | SELECT m.name || p."id", 150 | m.name, 151 | p."from", 152 | p."table", 153 | p."to" 154 | FROM sqlite_master m 155 | JOIN pragma_foreign_key_list(m.name) p ON m.name != p."table" 156 | WHERE m.type = 'table' 157 | ORDER BY 1, p."seq" 158 | `) 159 | if err != nil { 160 | log.Fatal(err) 161 | } 162 | defer func() { _ = rows.Close() }() 163 | return parseForeignKeys(rows, schemaName) 164 | } 165 | 166 | func (db *SQLite3DBRepository) Exec(ctx context.Context, query string) (sql.Result, error) { 167 | return db.Conn.ExecContext(ctx, query) 168 | } 169 | 170 | func (db *SQLite3DBRepository) Query(ctx context.Context, query string) (*sql.Rows, error) { 171 | return db.Conn.QueryContext(ctx, query) 172 | } 173 | 
-------------------------------------------------------------------------------- /internal/database/vertica.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "github.com/sqls-server/sqls/dialect" 8 | _ "github.com/vertica/vertica-sql-go" 9 | "log" 10 | "strconv" 11 | ) 12 | 13 | func init() { 14 | RegisterOpen("vertica", verticaOpen) 15 | RegisterFactory("vertica", NewVerticaDBRepository) 16 | } 17 | 18 | func verticaOpen(dbConnCfg *DBConfig) (*DBConnection, error) { 19 | var ( 20 | conn *sql.DB 21 | ) 22 | DSName, err := genVerticaConfig(dbConnCfg) 23 | if err != nil { 24 | return nil, err 25 | } 26 | 27 | conn, err = sql.Open("vertica", DSName) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | conn.SetMaxIdleConns(DefaultMaxIdleConns) 33 | conn.SetMaxOpenConns(DefaultMaxOpenConns) 34 | 35 | return &DBConnection{ 36 | Conn: conn, 37 | Driver: dialect.DatabaseDriverVertica, 38 | }, nil 39 | } 40 | 41 | func genVerticaConfig(connCfg *DBConfig) (string, error) { 42 | if connCfg.DataSourceName != "" { 43 | return connCfg.DataSourceName, nil 44 | } 45 | 46 | host, port := connCfg.Host, connCfg.Port 47 | if host == "" { 48 | host = "127.0.0.1" 49 | } 50 | if port == 0 { 51 | port = 5433 52 | } 53 | 54 | DSName := connCfg.User + "/" + connCfg.Passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName 55 | return DSName, nil 56 | } 57 | 58 | type VerticaDBRepository struct { 59 | Conn *sql.DB 60 | } 61 | 62 | func NewVerticaDBRepository(conn *sql.DB) DBRepository { 63 | return &VerticaDBRepository{Conn: conn} 64 | } 65 | 66 | func (db *VerticaDBRepository) Driver() dialect.DatabaseDriver { 67 | return dialect.DatabaseDriverVertica 68 | } 69 | 70 | func (db *VerticaDBRepository) CurrentDatabase(ctx context.Context) (string, error) { 71 | row := db.Conn.QueryRowContext(ctx, "SELECT CURRENT_SCHEMA()") 72 | var database string 73 | if err 
:= row.Scan(&database); err != nil { 74 | return "", err 75 | } 76 | return database, nil 77 | } 78 | 79 | func (db *VerticaDBRepository) Databases(ctx context.Context) ([]string, error) { 80 | rows, err := db.Conn.QueryContext(ctx, "SELECT schema_name FROM v_catalog.schemata") 81 | if err != nil { 82 | return nil, err 83 | } 84 | defer rows.Close() 85 | databases := []string{} 86 | for rows.Next() { 87 | var database string 88 | if err := rows.Scan(&database); err != nil { 89 | return nil, err 90 | } 91 | databases = append(databases, database) 92 | } 93 | return databases, nil 94 | } 95 | 96 | func (db *VerticaDBRepository) CurrentSchema(ctx context.Context) (string, error) { 97 | return db.CurrentDatabase(ctx) 98 | } 99 | 100 | func (db *VerticaDBRepository) Schemas(ctx context.Context) ([]string, error) { 101 | return db.Databases(ctx) 102 | } 103 | 104 | func (db *VerticaDBRepository) SchemaTables(ctx context.Context) (map[string][]string, error) { 105 | rows, err := db.Conn.QueryContext( 106 | ctx, 107 | ` 108 | SELECT schema_name, TABLE_NAME 109 | FROM v_catalog.all_tables 110 | ORDER BY schema_name, TABLE_NAME 111 | `) 112 | if err != nil { 113 | return nil, err 114 | } 115 | defer rows.Close() 116 | databaseTables := map[string][]string{} 117 | for rows.Next() { 118 | var schema, table string 119 | if err := rows.Scan(&schema, &table); err != nil { 120 | return nil, err 121 | } 122 | 123 | if arr, ok := databaseTables[schema]; ok { 124 | databaseTables[schema] = append(arr, table) 125 | } else { 126 | databaseTables[schema] = []string{table} 127 | } 128 | } 129 | return databaseTables, nil 130 | } 131 | 132 | func (db *VerticaDBRepository) Tables(ctx context.Context) ([]string, error) { 133 | rows, err := db.Conn.QueryContext(ctx, "SELECT table_name FROM v_catalog.tables ORDER BY 1") 134 | if err != nil { 135 | return nil, err 136 | } 137 | defer rows.Close() 138 | tables := []string{} 139 | for rows.Next() { 140 | var table string 141 | if err := 
rows.Scan(&table); err != nil { 142 | return nil, err 143 | } 144 | tables = append(tables, table) 145 | } 146 | return tables, nil 147 | } 148 | 149 | func (db *VerticaDBRepository) DescribeDatabaseTable(ctx context.Context) ([]*ColumnDesc, error) { 150 | rows, err := db.Conn.QueryContext( 151 | ctx, 152 | ` 153 | SELECT table_schema, 154 | table_name, 155 | column_name, 156 | data_type, 157 | is_nullable, 158 | '', 159 | column_default, 160 | '' 161 | FROM v_catalog.columns 162 | `) 163 | if err != nil { 164 | return nil, err 165 | } 166 | defer rows.Close() 167 | tableInfos := []*ColumnDesc{} 168 | for rows.Next() { 169 | var tableInfo ColumnDesc 170 | err := rows.Scan( 171 | &tableInfo.Schema, 172 | &tableInfo.Table, 173 | &tableInfo.Name, 174 | &tableInfo.Type, 175 | &tableInfo.Null, 176 | &tableInfo.Key, 177 | &tableInfo.Default, 178 | &tableInfo.Extra, 179 | ) 180 | if err != nil { 181 | return nil, err 182 | } 183 | tableInfos = append(tableInfos, &tableInfo) 184 | } 185 | return tableInfos, nil 186 | } 187 | 188 | func (db *VerticaDBRepository) DescribeDatabaseTableBySchema(ctx context.Context, schemaName string) ([]*ColumnDesc, error) { 189 | rows, err := db.Conn.QueryContext( 190 | ctx, 191 | ` 192 | SELECT table_schema, 193 | table_name, 194 | column_name, 195 | data_type, 196 | CASE is_nullable 197 | WHEN true THEN 'YES' 198 | ELSE 'NO' 199 | END AS is_nullable, 200 | '1' AS COLUMN_KEY, 201 | column_default, 202 | '1' AS EXTRA 203 | FROM v_catalog.columns 204 | WHERE table_schema = ? 
205 | `, schemaName) 206 | if err != nil { 207 | log.Println("schema", schemaName, err.Error()) 208 | return nil, err 209 | } 210 | tableInfos := []*ColumnDesc{} 211 | for rows.Next() { 212 | var tableInfo ColumnDesc 213 | err := rows.Scan( 214 | &tableInfo.Schema, 215 | &tableInfo.Table, 216 | &tableInfo.Name, 217 | &tableInfo.Type, 218 | &tableInfo.Null, 219 | &tableInfo.Key, 220 | &tableInfo.Default, 221 | &tableInfo.Extra, 222 | ) 223 | if err != nil { 224 | return nil, err 225 | } 226 | tableInfos = append(tableInfos, &tableInfo) 227 | } 228 | defer rows.Close() 229 | return tableInfos, nil 230 | } 231 | 232 | func (db *VerticaDBRepository) Exec(ctx context.Context, query string) (sql.Result, error) { 233 | return db.Conn.ExecContext(ctx, query) 234 | } 235 | 236 | func (db *VerticaDBRepository) Query(ctx context.Context, query string) (*sql.Rows, error) { 237 | return db.Conn.QueryContext(ctx, query) 238 | } 239 | 240 | func (db *VerticaDBRepository) DescribeForeignKeysBySchema(ctx context.Context, schemaName string) ([]*ForeignKey, error) { 241 | return nil, fmt.Errorf("describe foreign keys is not supported") 242 | } 243 | -------------------------------------------------------------------------------- /internal/database/worker.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "sync" 7 | ) 8 | 9 | type Worker struct { 10 | dbRepo DBRepository 11 | dbCache *DBCache 12 | 13 | done chan struct{} 14 | update chan struct{} 15 | lock sync.Mutex 16 | } 17 | 18 | func NewWorker() *Worker { 19 | return &Worker{ 20 | done: make(chan struct{}, 1), 21 | update: make(chan struct{}, 1), 22 | } 23 | } 24 | 25 | func (w *Worker) Cache() *DBCache { 26 | return w.dbCache 27 | } 28 | 29 | func (w *Worker) setCache(c *DBCache) { 30 | w.lock.Lock() 31 | defer w.lock.Unlock() 32 | w.dbCache = c 33 | } 34 | 35 | func (w *Worker) setColumnCache(col map[string][]*ColumnDesc) { 36 | 
w.lock.Lock() 37 | defer w.lock.Unlock() 38 | if w.dbCache != nil { 39 | w.dbCache.ColumnsWithParent = col 40 | } 41 | } 42 | 43 | func (w *Worker) Start() { 44 | go func() { 45 | log.Println("db worker: start") 46 | for { 47 | select { 48 | case <-w.done: 49 | log.Println("db worker: done") 50 | return 51 | case <-w.update: 52 | generator := NewDBCacheUpdater(w.dbRepo) 53 | col, err := generator.GenerateDBCacheSecondary(context.Background()) 54 | if err != nil { 55 | log.Println(err) 56 | } 57 | w.setColumnCache(col) 58 | log.Println("db worker: Update db cache secondary complete") 59 | } 60 | } 61 | }() 62 | } 63 | 64 | func (w *Worker) Stop() { 65 | close(w.done) 66 | } 67 | 68 | func (w *Worker) ReCache(ctx context.Context, repo DBRepository) error { 69 | w.dbRepo = repo 70 | if err := w.updateAllCache(ctx); err != nil { 71 | return err 72 | } 73 | w.updateAdditionalCache() 74 | return nil 75 | } 76 | 77 | func (w *Worker) updateAllCache(ctx context.Context) error { 78 | generator := NewDBCacheUpdater(w.dbRepo) 79 | cache, err := generator.GenerateDBCachePrimary(ctx) 80 | if err != nil { 81 | return err 82 | } 83 | w.setCache(cache) 84 | log.Println("db worker: Update db cache primary complete") 85 | return nil 86 | } 87 | 88 | func (w *Worker) updateAdditionalCache() { 89 | w.update <- struct{}{} 90 | } 91 | -------------------------------------------------------------------------------- /internal/debug/debug.go: -------------------------------------------------------------------------------- 1 | package debug 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | ) 7 | 8 | func DPrintln(a ...interface{}) { 9 | fmt.Fprintln(os.Stderr, a...) 10 | } 11 | 12 | func DPrintf(format string, a ...interface{}) { 13 | fmt.Fprintf(os.Stderr, format, a...) 
14 | } 15 | -------------------------------------------------------------------------------- /internal/formatter/formatter_test.go: -------------------------------------------------------------------------------- 1 | package formatter 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "runtime" 7 | "slices" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/sqls-server/sqls/ast" 12 | "github.com/sqls-server/sqls/internal/config" 13 | "github.com/sqls-server/sqls/internal/lsp" 14 | "github.com/sqls-server/sqls/parser" 15 | ) 16 | 17 | func TestEval(t *testing.T) { 18 | testcases := []struct { 19 | name string 20 | input string 21 | params lsp.DocumentFormattingParams 22 | config *config.Config 23 | expected string 24 | }{ 25 | { 26 | name: "InsertIntoFormat", 27 | input: "INSERT INTO users (NAME, email) VALUES ('john doe', 'example@host.com')", 28 | expected: "INSERT INTO users(\n\tNAME,\n\temail\n)\nVALUES(\n\t'john doe',\n\t'example@host.com'\n)", 29 | params: lsp.DocumentFormattingParams{}, 30 | config: &config.Config{ 31 | LowercaseKeywords: false, 32 | }, 33 | }, 34 | } 35 | 36 | for _, tt := range testcases { 37 | t.Run(tt.name, func(t *testing.T) { 38 | actual, _ := Format(tt.input, tt.params, tt.config) 39 | if actual[0].NewText != tt.expected { 40 | t.Errorf("expected: %s, got %s", tt.expected, actual[0].NewText) 41 | } 42 | }) 43 | } 44 | } 45 | 46 | func TestRenderIdentifier(t *testing.T) { 47 | testcases := []struct { 48 | name string 49 | input string 50 | opts *ast.RenderOptions 51 | expected []string 52 | }{ 53 | { 54 | name: "snake case", 55 | input: "SELECT * FROM snake_case_table_name", 56 | opts: &ast.RenderOptions{ 57 | LowerCase: false, 58 | IdentifierQuoted: false, 59 | }, 60 | expected: []string{ 61 | "*", 62 | "snake_case_table_name", 63 | }, 64 | }, 65 | { 66 | name: "pascal case", 67 | input: "SELECT p.PascalCaseColumnName FROM \"PascalCaseTableName\" p", 68 | opts: &ast.RenderOptions{ 69 | LowerCase: false, 70 | IdentifierQuoted: false, 71 | 
}, 72 | expected: []string{ 73 | "p.PascalCaseColumnName", 74 | "\"PascalCaseTableName\"", 75 | }, 76 | }, 77 | { 78 | name: "quoted pascal case", 79 | input: "SELECT p.\"PascalCaseColumnName\" FROM \"PascalCaseTableName\" p", 80 | opts: &ast.RenderOptions{ 81 | LowerCase: false, 82 | IdentifierQuoted: false, 83 | }, 84 | expected: []string{ 85 | "p.\"PascalCaseColumnName\"", 86 | "\"PascalCaseTableName\"", 87 | }, 88 | }, 89 | } 90 | 91 | for _, tt := range testcases { 92 | t.Run(tt.name, func(t *testing.T) { 93 | stmts := parseInit(t, tt.input) 94 | list := stmts[0].GetTokens() 95 | j := 0 96 | for _, n := range list { 97 | if i, ok := n.(*ast.Identifier); ok { 98 | if actual := i.Render(tt.opts); actual != tt.expected[j] { 99 | t.Errorf("expected: %s, got %s", tt.expected[j], actual) 100 | } 101 | j++ 102 | } 103 | } 104 | }) 105 | } 106 | } 107 | 108 | func parseInit(t *testing.T, input string) []*ast.Statement { 109 | t.Helper() 110 | parsed, err := parser.Parse(input) 111 | if err != nil { 112 | t.Fatalf("error %+v\n", err) 113 | } 114 | 115 | var stmts []*ast.Statement 116 | for _, node := range parsed.GetTokens() { 117 | stmt, ok := node.(*ast.Statement) 118 | if !ok { 119 | t.Fatalf("invalid type want Statement parsed %T", stmt) 120 | } 121 | stmts = append(stmts, stmt) 122 | } 123 | return stmts 124 | } 125 | 126 | func TestFormat(t *testing.T) { 127 | _, filename, _, _ := runtime.Caller(0) 128 | dir := filepath.Dir(filename) 129 | 130 | files, err := filepath.Glob(filepath.Join(dir, "testdata", "*.sql")) 131 | if err != nil { 132 | t.Fatal(err) 133 | } 134 | slices.Sort(files) 135 | 136 | opts := &ast.RenderOptions{ 137 | LowerCase: false, 138 | IdentifierQuoted: false, 139 | } 140 | for _, fname := range files { 141 | b, err := os.ReadFile(fname) 142 | if err != nil { 143 | t.Fatal(err) 144 | } 145 | parsed, err := parser.Parse(string(b)) 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | env := &formatEnvironment{} 150 | formatted := Eval(parsed, 
env) 151 | got := strings.TrimRight(formatted.Render(opts), "\n") + "\n" 152 | 153 | b, err = os.ReadFile(fname[:len(fname)-4] + ".golden") 154 | if err != nil { 155 | t.Fatal(err) 156 | } 157 | want := string(b) 158 | if got != want { 159 | if _, err := os.Stat(fname[:len(fname)-4] + ".ignore"); err == nil { 160 | t.Logf("%s:\n"+ 161 | " want: %q\n"+ 162 | " got: %q\n", 163 | fname, want, got) 164 | } else { 165 | t.Errorf("%s:\n"+ 166 | " want: %q\n"+ 167 | " got: %q\n", 168 | fname, want, got) 169 | } 170 | } 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /internal/formatter/formatutil.go: -------------------------------------------------------------------------------- 1 | package formatter 2 | 3 | import ( 4 | "github.com/sqls-server/sqls/ast" 5 | "github.com/sqls-server/sqls/token" 6 | ) 7 | 8 | func unshift(slice []ast.Node, node ...ast.Node) []ast.Node { 9 | return append(node, slice...) 10 | } 11 | 12 | var whitespaceNode = ast.NewItem(&token.Token{ 13 | Kind: token.Whitespace, 14 | Value: " ", 15 | }) 16 | 17 | func whiteSpaceNodes(num int) []ast.Node { 18 | res := make([]ast.Node, num) 19 | for i := 0; i < num; i++ { 20 | res[i] = whitespaceNode 21 | } 22 | return res 23 | } 24 | 25 | var linebreakNode = ast.NewItem(&token.Token{ 26 | Kind: token.Whitespace, 27 | Value: "\n", 28 | }) 29 | 30 | var tabNode = ast.NewItem(&token.Token{ 31 | Kind: token.Whitespace, 32 | Value: "\t", 33 | }) 34 | 35 | var periodNode = ast.NewItem(&token.Token{ 36 | Kind: token.Period, 37 | Value: ".", 38 | }) 39 | 40 | var lparenNode = ast.NewItem(&token.Token{ 41 | Kind: token.LParen, 42 | Value: "(", 43 | }) 44 | 45 | var rparenNode = ast.NewItem(&token.Token{ 46 | Kind: token.RParen, 47 | Value: ")", 48 | }) 49 | 50 | var commaNode = ast.NewItem(&token.Token{ 51 | Kind: token.Comma, 52 | Value: ",", 53 | }) 54 | -------------------------------------------------------------------------------- 
/internal/formatter/testdata/001.golden: -------------------------------------------------------------------------------- 1 | SELECT 2 | * 3 | FROM 4 | TABLE 5 | -------------------------------------------------------------------------------- /internal/formatter/testdata/001.sql: -------------------------------------------------------------------------------- 1 | select * From table 2 | -------------------------------------------------------------------------------- /internal/formatter/testdata/002.golden: -------------------------------------------------------------------------------- 1 | -- hoge -- 2 | SELECT 3 | x/*x*/, 4 | /*x*/y 5 | FROM 6 | zzz; 7 | -- zzzz 8 | SELECT 9 | * 10 | FROM 11 | yyy; 12 | -- yyyy 13 | -- hage -- 14 | -------------------------------------------------------------------------------- /internal/formatter/testdata/002.sql: -------------------------------------------------------------------------------- 1 | -- hoge -- 2 | SELECT x/*x*/, /*x*/y FROM zzz; -- zzzz 3 | SELECT * FROM yyy; -- yyyy 4 | -- hage -- 5 | -------------------------------------------------------------------------------- /internal/formatter/testdata/003.golden: -------------------------------------------------------------------------------- 1 | -- hoge -- 2 | SELECT 3 | x/*x*/, 4 | /*x*/y 5 | FROM 6 | zzz; 7 | -- zzzz 8 | SELECT 9 | * 10 | FROM 11 | yyy; 12 | -- yyyy 13 | -- hage -- 14 | SELECT 15 | ap.autograph_purchase_id AS "id", 16 | ap.order_number AS "order", 17 | ap.product_price AS "productPrice", 18 | i.name AS "influencerName", 19 | p.name AS "productName", 20 | u.email AS "email" 21 | FROM 22 | autograph_purchaseASap innser 23 | JOIN influencer AS i 24 | ON autograph_purchase.influencer_id = influencer.influencer_id 25 | LEFT JOIN product AS p 26 | ON product.product_id = autograph_purchase.product_id 27 | LEFT JOIN USER AS u 28 | ON USER.user_id = autograph_purchase.user_id 29 | -------------------------------------------------------------------------------- 
/internal/formatter/testdata/003.ignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sqls-server/sqls/efe7f66d16e9479e242d3876c2a4a878ee190568/internal/formatter/testdata/003.ignore -------------------------------------------------------------------------------- /internal/formatter/testdata/003.sql: -------------------------------------------------------------------------------- 1 | SELECT 2 | ap.autograph_purchase_id AS "id", 3 | ap.order_number AS "order", 4 | ap.product_price AS "productPrice", 5 | i.name AS "influencerName", 6 | p.name AS "productName", 7 | u.email AS "email" 8 | FROM 9 | autograph_purchaseASap innser 10 | JOIN influencer AS i 11 | ON autograph_purchase.influencer_id = influencer.influencer_id 12 | LEFT JOIN product AS p 13 | ON product.product_id = autograph_purchase.product_id 14 | LEFT JOIN USER1 AS u 15 | ON USER1.user_id = autograph_purchase.user_id 16 | 
-------------------------------------------------------------------------------- /internal/handler/completion.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/sourcegraph/jsonrpc2" 9 | "github.com/sqls-server/sqls/internal/completer" 10 | "github.com/sqls-server/sqls/internal/lsp" 11 | ) 12 | 13 | // handleTextDocumentCompletion returns completion items for the position in the 14 | // request. The completer is built from the worker's DB cache, and its Driver is 15 | // the current connection's driver, or "" when there is no connection. 16 | func (s *Server) handleTextDocumentCompletion(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) (result interface{}, err error) { 17 | if req.Params == nil { 18 | return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} 19 | } 20 | 21 | var params lsp.CompletionParams 22 | if err := json.Unmarshal(*req.Params, &params); err != nil { 23 | return nil, err 24 | } 25 | 26 | f, ok := s.files[params.TextDocument.URI] 27 | if !ok { 28 | return nil, fmt.Errorf("document not found: %s", params.TextDocument.URI) 29 | } 30 | 31 | c := completer.NewCompleter(s.worker.Cache()) 32 | if s.dbConn != nil { 33 | c.Driver = s.dbConn.Driver 34 | } else { 35 | c.Driver = "" 36 | } 37 | completionItems, err := c.Complete(f.Text, params, s.getConfig().LowercaseKeywords) 38 | if err != nil { 39 | return nil, err 40 | } 41 | return completionItems, nil 42 | } 43 | 
-------------------------------------------------------------------------------- /internal/handler/definition.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/sourcegraph/jsonrpc2" 9 | "github.com/sqls-server/sqls/ast" 10 | "github.com/sqls-server/sqls/ast/astutil" 11 | "github.com/sqls-server/sqls/internal/database" 12 | "github.com/sqls-server/sqls/internal/lsp" 13 | "github.com/sqls-server/sqls/parser" 14 | "github.com/sqls-server/sqls/parser/parseutil" 15 | "github.com/sqls-server/sqls/token" 16 | ) 17 | 18 | // handleDefinition decodes a definition request for an open document and delegates to definition. 19 | func (s *Server) handleDefinition(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) (result interface{}, err error) { 20 | if req.Params == nil { 21 | return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} 22 | } 23 | 24 | var params lsp.DefinitionParams 25 | if err := json.Unmarshal(*req.Params, &params); err != nil { 26 | return nil, err 27 | } 28 | 29 | f, ok := s.files[params.TextDocument.URI] 30 | if !ok { 31 | return nil, fmt.Errorf("document not found: %s", params.TextDocument.URI) 32 | } 33 | 34 | return definition(params.TextDocument.URI, f.Text, params, s.worker.Cache()) 35 | } 36 | 37 | // definition returns the location of the alias whose name equals the identifier 38 | // at params.Position in text, or nil when there is no identifier there or no alias 39 | // matches; it errors only if parsing fails. dbCache is not used by this implementation. 40 | func definition(url, text string, params lsp.DefinitionParams, dbCache *database.DBCache) (lsp.Definition, error) { 41 | pos := token.Pos{ 42 | Line: params.Position.Line, 43 | // NOTE(review): +1 presumably maps the 0-based LSP character onto the parser's 44 | // column; rename.go builds its token.Pos without it, so confirm which is intended. 45 | Col: params.Position.Character + 1, 46 | } 47 | parsed, err := parser.Parse(text) 48 | if err != nil { 49 | return nil, err 50 | } 51 | 52 | nodeWalker := parseutil.NewNodeWalker(parsed, pos) 53 | m := astutil.NodeMatcher{ 54 | NodeTypes: []ast.NodeType{ast.TypeIdentifier}, 55 | } 56 | currentVariable := nodeWalker.CurNodeBottomMatched(m) 57 | if currentVariable == nil { 58 | return nil, nil 59 | } 60 | 61 | aliases := parseutil.ExtractAliased(parsed) 62 | if len(aliases) == 0 { 63 | return nil, nil 64 | } 65 | 66 | var define ast.Node 67 | for _, v := range aliases { 68 | // Skip non-alias nodes instead of dereferencing a nil *ast.Aliased below. 69 | alias, ok := v.(*ast.Aliased) 70 | if !ok { 71 | continue 72 | } 73 | if alias.AliasedName.String() == currentVariable.String() { 74 | define = alias.AliasedName 75 | break 76 | } 77 | } 78 | 79 | if define == nil { 80 | return nil, nil 81 | } 82 | 83 | res := []lsp.Location{ 84 | { 85 | URI: url, 86 | Range: lsp.Range{ 87 | Start: lsp.Position{ 88 | Line: define.Pos().Line, 89 | Character: define.Pos().Col, 90 | }, 91 | End: lsp.Position{ 92 | Line: define.End().Line, 93 | Character: define.End().Col, 94 | }, 95 | }, 96 | }, 97 | } 98 | 99 | return res, nil 100 | } 101 | 
-------------------------------------------------------------------------------- /internal/handler/definition_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/google/go-cmp/cmp" 7 | "github.com/sqls-server/sqls/internal/config" 8 | "github.com/sqls-server/sqls/internal/database" 9 | "github.com/sqls-server/sqls/internal/lsp" 10 | ) 11 | 12 | var definitionTestCases = []struct { 13 | name string 14 | input string 15 | pos lsp.Position 16 | want lsp.Definition 17 | }{ 18 | { 19 | name: "subquery", 20 | input: "SELECT it.ID, it.Name FROM (SELECT ci.ID, ci.Name, ci.CountryCode, ci.District, ci.Population FROM city AS ci) as it", 21 | pos: lsp.Position{ 22 | Line: 0, 23 | Character: 8, 24 | }, 25 | want: []lsp.Location{ 26 | { 27 | URI: testFileURI, 28 | Range: lsp.Range{ 29 | Start: lsp.Position{ 30 | Line: 0, 31 | Character: 114, 32 | }, 33 | End: lsp.Position{ 34 | Line: 0, 35 | Character: 116, 36 | }, 37 | }, 38 | }, 39 | }, 40 | }, 41 | { 42 | name: "inner subquery", 43 | input: "SELECT it.ID, it.Name FROM (SELECT ci.ID, ci.Name, ci.CountryCode, ci.District, ci.Population FROM city AS ci) as it", 44 | pos: lsp.Position{ 45 | Line: 0, 46 | Character: 36, 47 | }, 48 | want: []lsp.Location{ 49 | { 50 | URI: testFileURI, 51 | Range: lsp.Range{ 52 | Start: lsp.Position{ 53 | Line: 0, 54 | Character: 107, 55 | }, 56 | End: lsp.Position{ 57 | Line: 0, 58 | Character: 109, 59 | }, 60 | }, 61 | }, 62 | }, 63 | }, 64 | { 65 | name: "alias", 66 | input: "SELECT ci.ID, ci.Name FROM city as ci", 67 | pos: lsp.Position{ 68 | Line: 0, 69 | Character: 8, 70 | }, 71 | want: []lsp.Location{ 72 | { 73 | URI: testFileURI, 74 | Range: lsp.Range{ 75 | Start: lsp.Position{ 76 | Line: 0, 77 | Character: 35, 78 | }, 79 | End: lsp.Position{ 80 | Line: 0, 81 | Character: 37, 82 | }, 83 | }, 84 | }, 85 | }, 86 | }, 87 | } 88 | 89 | func TestDefinition(t *testing.T) { 90 | tx := newTestContext() 91 | tx.setup(t) 92 | defer tx.tearDown() 93 | 94 | cfg := &config.Config{ 95 | Connections: []*database.DBConfig{ 96 | {Driver: "mock"}, 97 | }, 98 | } 99 | tx.addWorkspaceConfig(t, cfg) 100 | 101 | for _, tt := range definitionTestCases { 102 | t.Run(tt.name, func(t *testing.T) { 103 | tx.textDocumentDidOpen(t, testFileURI, tt.input) 104 | 105 | params := lsp.DefinitionParams{ 106 | TextDocumentPositionParams: lsp.TextDocumentPositionParams{ 107 | TextDocument: lsp.TextDocumentIdentifier{ 108 | URI: testFileURI, 109 | }, 110 | Position: tt.pos, 111 | }, 112 | } 113 | var got lsp.Definition 114 | err := tx.conn.Call(tx.ctx, "textDocument/definition", params, &got) 115 | if err != nil { 116 | t.Errorf("conn.Call textDocument/definition: %+v", err) 117 | return 118 | } 119 | 120 | if diff := cmp.Diff(tt.want, got); diff != "" { 121 | t.Errorf("unmatch definition (- want, + got):\n%s", diff) 122 | } 123 | }) 124 | } 125 | } 126 | 127 | func TestTypeDefinition(t *testing.T) { 128 | tx := newTestContext() 129 | tx.setup(t) 130 | defer tx.tearDown() 131 | 132 | cfg := &config.Config{ 133 | Connections: []*database.DBConfig{ 134 | {Driver: "mock"}, 135 | }, 136 | } 137 | tx.addWorkspaceConfig(t, cfg) 138 | 139 | for _, tt := range definitionTestCases { 140 | t.Run(tt.name, func(t *testing.T) { 141 | tx.textDocumentDidOpen(t, testFileURI, tt.input) 142 | 143 | params := lsp.DefinitionParams{ 144 | TextDocumentPositionParams: lsp.TextDocumentPositionParams{ 145 | TextDocument: lsp.TextDocumentIdentifier{ 146 | URI: testFileURI, 147 | }, 148 | Position: tt.pos, 149 | }, 150 | } 151 | var got lsp.Definition 152 | err := tx.conn.Call(tx.ctx, "textDocument/typeDefinition", params, &got) 153 | if err != nil { 154 | t.Errorf("conn.Call textDocument/typeDefinition: %+v", err) 155 | return 156 | } 157 | 158 | if diff := cmp.Diff(tt.want, got); diff != "" { 159 | t.Errorf("unmatch type definition (- want, + got):\n%s", diff) 160 | } 161 | }) 162 | } 163 | } 164 | 
-------------------------------------------------------------------------------- /internal/handler/execute_command_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/sqls-server/sqls/internal/config" 7 | "github.com/sqls-server/sqls/internal/database" 8 | "github.com/sqls-server/sqls/internal/lsp" 9 | ) 10 | 11 | func Test_executeQuery(t *testing.T) { 12 | tx := newTestContext() 13 | tx.setup(t) 14 | defer tx.tearDown() 15 | 16 | didChangeConfigurationParams := lsp.DidChangeConfigurationParams{ 17 | Settings: struct { 18 | SQLS *config.Config "json:\"sqls\"" 19 | }{ 20 | SQLS: &config.Config{ 21 | Connections: []*database.DBConfig{ 22 | { 23 | Driver: "mock", 24 | DataSourceName: "", 25 | }, 26 | }, 27 | }, 28 | }, 29 | } 30 | if err := tx.conn.Call(tx.ctx, "workspace/didChangeConfiguration", didChangeConfigurationParams, nil); err != nil { 31 | t.Fatal("conn.Call workspace/didChangeConfiguration:", err) 32 | } 33 | 34 | uri := "file:///test.sql" 35 | text := "SELECT 1; SELECT 2;" 36 | 
didOpenParams := lsp.DidOpenTextDocumentParams{ 37 | TextDocument: lsp.TextDocumentItem{ 38 | URI: uri, 39 | LanguageID: "sql", 40 | Version: 0, 41 | Text: text, 42 | }, 43 | } 44 | if err := tx.conn.Call(tx.ctx, "textDocument/didOpen", didOpenParams, nil); err != nil { 45 | t.Fatal("conn.Call textDocument/didOpen:", err) 46 | } 47 | tx.testFile(t, didOpenParams.TextDocument.URI, didOpenParams.TextDocument.Text) 48 | 49 | // executeCommandParams := lsp.ExecuteCommandParams{ 50 | // Command: CommandExecuteQuery, 51 | // Arguments: []interface{}{uri}, 52 | // } 53 | // var got interface{} 54 | // tx.conn.Call(tx.ctx, "workspace/executeCommand", executeCommandParams, &got) 55 | // pass error 56 | } 57 | 58 | func Test_extractRangeText(t *testing.T) { 59 | type args struct { 60 | text string 61 | startLine int 62 | startChar int 63 | endLine int 64 | endChar int 65 | } 66 | tests := []struct { 67 | name string 68 | args args 69 | want string 70 | }{ 71 | { 72 | name: "extract single line", 73 | args: args{ 74 | text: "select * from city", 75 | startLine: 0, 76 | startChar: 0, 77 | endLine: 0, 78 | endChar: 8, 79 | }, 80 | want: "select *", 81 | }, 82 | { 83 | name: "extract multi line with not equal start end", 84 | args: args{ 85 | text: "select 1;\nselect 2;\nselect 3;", 86 | startLine: 0, 87 | startChar: 7, 88 | endLine: 2, 89 | endChar: 8, 90 | }, 91 | want: "1;\nselect 2;\nselect 3", 92 | }, 93 | { 94 | name: "extract multi line with equal start end", 95 | args: args{ 96 | text: "select 1;\nselect 2;\nselect 3;", 97 | startLine: 1, 98 | startChar: 2, 99 | endLine: 1, 100 | endChar: 6, 101 | }, 102 | want: "lect", 103 | }, 104 | } 105 | for _, tt := range tests { 106 | t.Run(tt.name, func(t *testing.T) { 107 | if got := extractRangeText(tt.args.text, tt.args.startLine, tt.args.startChar, tt.args.endLine, tt.args.endChar); got != tt.want { 108 | t.Errorf("got %q, want %q", got, tt.want) 109 | } 110 | }) 111 | } 112 | } 113 | 
-------------------------------------------------------------------------------- /internal/handler/format.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/sourcegraph/jsonrpc2" 9 | "github.com/sqls-server/sqls/internal/formatter" 10 | "github.com/sqls-server/sqls/internal/lsp" 11 | ) 12 | 13 | func (s *Server) handleTextDocumentFormatting(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) (result interface{}, err error) { 14 | if req.Params == nil { 15 | return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} 16 | } 17 | 18 | var params lsp.DocumentFormattingParams 19 | if err := json.Unmarshal(*req.Params, ¶ms); err != nil { 20 | return nil, err 21 | } 22 | 23 | f, ok := s.files[params.TextDocument.URI] 24 | if !ok { 25 | return nil, fmt.Errorf("document not found: %s", params.TextDocument.URI) 26 | } 27 | 28 | textEdits, err := formatter.Format(f.Text, params, s.getConfig()) 29 | if err != nil { 30 | return nil, err 31 | } 32 | if len(textEdits) > 0 { 33 | return textEdits, nil 34 | } 35 | return nil, nil 36 | } 37 | 38 | func (s *Server) handleTextDocumentRangeFormatting(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) (result interface{}, err error) { 39 | if req.Params == nil { 40 | return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} 41 | } 42 | 43 | var params lsp.DocumentRangeFormattingParams 44 | if err := json.Unmarshal(*req.Params, ¶ms); err != nil { 45 | return nil, err 46 | } 47 | 48 | _, ok := s.files[params.TextDocument.URI] 49 | if !ok { 50 | return nil, fmt.Errorf("document not found: %s", params.TextDocument.URI) 51 | } 52 | 53 | textEdits := []lsp.TextEdit{} 54 | if len(textEdits) > 0 { 55 | return textEdits, nil 56 | } 57 | return nil, nil 58 | } 59 | -------------------------------------------------------------------------------- /internal/handler/format_test.go: 
-------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/google/go-cmp/cmp" 11 | "github.com/sqls-server/sqls/internal/config" 12 | "github.com/sqls-server/sqls/internal/lsp" 13 | ) 14 | 15 | var formattingOptionTab = lsp.FormattingOptions{ 16 | TabSize: 0.0, 17 | InsertSpaces: false, 18 | TrimTrailingWhitespace: false, 19 | InsertFinalNewline: false, 20 | TrimFinalNewlines: false, 21 | } 22 | 23 | var formattingOptionIndentSpace2 = lsp.FormattingOptions{ 24 | TabSize: 2.0, 25 | InsertSpaces: true, 26 | TrimTrailingWhitespace: false, 27 | InsertFinalNewline: false, 28 | TrimFinalNewlines: false, 29 | } 30 | 31 | var formattingOptionIndentSpace4 = lsp.FormattingOptions{ 32 | TabSize: 4.0, 33 | InsertSpaces: true, 34 | TrimTrailingWhitespace: false, 35 | InsertFinalNewline: false, 36 | TrimFinalNewlines: false, 37 | } 38 | 39 | var upperCaseConfig = &config.Config{ 40 | LowercaseKeywords: false, 41 | } 42 | 43 | var lowerCaseConfig = &config.Config{ 44 | LowercaseKeywords: true, 45 | } 46 | 47 | type formattingTestCase struct { 48 | name string 49 | input string 50 | want string 51 | } 52 | 53 | func testFormatting(t *testing.T, testCases []formattingTestCase, options lsp.FormattingOptions, cfg *config.Config) { 54 | tx := newTestContext() 55 | tx.initServer(t) 56 | defer tx.tearDown() 57 | 58 | didChangeConfigurationParams := lsp.DidChangeConfigurationParams{ 59 | Settings: struct { 60 | SQLS *config.Config "json:\"sqls\"" 61 | }{ 62 | SQLS: cfg, 63 | }, 64 | } 65 | 66 | if err := tx.conn.Call(tx.ctx, "workspace/didChangeConfiguration", didChangeConfigurationParams, nil); err != nil { 67 | t.Fatal("conn.Call workspace/didChangeConfiguration:", err) 68 | } 69 | uri := "file:///Users/octref/Code/css-test/test.sql" 70 | for _, tt := range testCases { 71 | t.Run(tt.name, func(t *testing.T) { 72 | // Open dummy file 73 | 
didOpenParams := lsp.DidOpenTextDocumentParams{ 74 | TextDocument: lsp.TextDocumentItem{ 75 | URI: uri, 76 | LanguageID: "sql", 77 | Version: 0, 78 | Text: tt.input, 79 | }, 80 | } 81 | if err := tx.conn.Call(tx.ctx, "textDocument/didOpen", didOpenParams, nil); err != nil { 82 | t.Fatal("conn.Call textDocument/didOpen:", err) 83 | } 84 | tx.testFile(t, didOpenParams.TextDocument.URI, didOpenParams.TextDocument.Text) 85 | // Create formatting params 86 | formattingParams := lsp.DocumentFormattingParams{ 87 | TextDocument: lsp.TextDocumentIdentifier{ 88 | URI: uri, 89 | }, 90 | Options: options, 91 | WorkDoneProgressParams: lsp.WorkDoneProgressParams{ 92 | WorkDoneToken: nil, 93 | }, 94 | } 95 | 96 | var got []lsp.TextEdit 97 | if err := tx.conn.Call(tx.ctx, "textDocument/formatting", formattingParams, &got); err != nil { 98 | t.Fatal("conn.Call textDocument/formatting:", err) 99 | } 100 | if len(got) == 0 { 101 | t.Fatalf("textDocument/formatting returned no edits, want %q", tt.want) 102 | } 103 | if diff := cmp.Diff(tt.want, got[0].NewText); diff != "" { 104 | t.Errorf("unmatch (- want, + got):\n%s", diff) 105 | t.Errorf("unmatch\nwant: %q\ngot : %q", tt.want, got[0].NewText) 106 | } 107 | }) 108 | } 109 | } 110 | 111 | func TestFormattingBase(t *testing.T) { 112 | testCase, err := loadFormatTestCaseByTestdata("format") 113 | if err != nil { 114 | t.Fatal(err) 115 | } 116 | testFormatting(t, testCase, formattingOptionTab, lowerCaseConfig) 117 | } 118 | 119 | func TestFormattingMinimal(t *testing.T) { 120 | // Add minimal case test 121 | minimalTestCase := []formattingTestCase{ 122 | { 123 | name: "multi keyword", 124 | input: " inner join ", 125 | want: "inner join", 126 | }, 127 | { 128 | name: "aliased", 129 | input: "foo as f", 130 | want: "foo as f", 131 | }, 132 | { 133 | name: "member identifier", 134 | input: "foo.id", 135 | want: "foo.id", 136 | }, 137 | { 138 | name: "operator", 139 | input: "1+ 2 - 3 * 4", 140 | want: "1 + 2 - 3 * 4", 141 | }, 142 | { 143 | name: "comparison", 144 | input: "1 < 2", 145 | want: "1 < 2", 146 | }, 147 | // { 148 | // name: "parenthesis", 149 | // input: "( 1 + 2 ) = 3", 150 | // want: "(1 + 2) = 3", 151 | // }, 152 | { 153 | name: "identifier list", 154 | input: "1 , 2 , 3 , 4", 155 | want: "1,\n2,\n3,\n4", 156 | }, 157 | } 158 | testFormatting(t, minimalTestCase, formattingOptionTab, lowerCaseConfig) 159 | } 160 | 161 | func TestFormattingWithOptionSpace2(t *testing.T) { 162 | testCase, err := loadFormatTestCaseByTestdata("format_option_space2") 163 | if err != nil { 164 | t.Fatal(err) 165 | } 166 | testFormatting(t, testCase, formattingOptionIndentSpace2, lowerCaseConfig) 167 | } 168 | 169 | func TestFormattingWithOptionSpace4(t *testing.T) { 170 | testCase, err := loadFormatTestCaseByTestdata("format_option_space4") 171 | if err != nil { 172 | t.Fatal(err) 173 | } 174 | testFormatting(t, testCase, formattingOptionIndentSpace4, lowerCaseConfig) 175 | } 176 | 177 | func TestFormattingWithOptionUpper(t *testing.T) { 178 | testCase, err := loadFormatTestCaseByTestdata("upper_case") 179 | if err != nil { 180 | t.Fatal(err) 181 | } 182 | testFormatting(t, testCase, formattingOptionTab, upperCaseConfig) 183 | } 184 | 185 | // loadFormatTestCaseByTestdata builds test cases from testdata/<targetDir> under the 186 | // working directory: each NAME.input.sql is paired with NAME.golden.sql. 187 | func loadFormatTestCaseByTestdata(targetDir string) ([]formattingTestCase, error) { 188 | packageDir, err := os.Getwd() 189 | if err != nil { 190 | return nil, err 191 | } 192 | testDir := filepath.Join(packageDir, "testdata", targetDir) 193 | testFileInfos, err := os.ReadDir(testDir) 194 | if err != nil { 195 | return nil, err 196 | } 197 | 198 | testCase := []formattingTestCase{} 199 | const ( 200 | inputFileSuffix = ".input.sql" 201 | goldenFileSuffix = ".golden.sql" 202 | ) 203 | 204 | for _, testFileInfo := range testFileInfos { 205 | inputFileName := testFileInfo.Name() 206 | if !strings.HasSuffix(inputFileName, inputFileSuffix) { 207 | continue 208 | } 209 | 210 | testName := strings.TrimSuffix(inputFileName, inputFileSuffix) 211 | inputPath := filepath.Join(testDir, inputFileName) 212 | goldenPath := filepath.Join(testDir, testName+goldenFileSuffix) 213 | 214 | input, err := os.ReadFile(inputPath) 215 | if err != nil { 216 | return nil, fmt.Errorf("cannot read input file %s: %w", inputPath, err) 217 | } 218 | golden, err := os.ReadFile(goldenPath) 219 | if err != nil { 220 | return nil, fmt.Errorf("cannot read golden file %s: %w", goldenPath, err) 221 | } 222 | testCase = append(testCase, formattingTestCase{ 223 | name: testName, 224 | input: string(input), 225 | // Drop the golden file's trailing newline, if it has one. 226 | want: strings.TrimSuffix(string(golden), "\n"), 227 | }) 228 | } 229 | return testCase, nil 230 | } 231 | 
-------------------------------------------------------------------------------- /internal/handler/handler_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "log" 7 | "net" 8 | "reflect" 9 | "testing" 10 | 11 | "github.com/sourcegraph/jsonrpc2" 12 | 13 | "github.com/sqls-server/sqls/internal/config" 14 | "github.com/sqls-server/sqls/internal/lsp" 15 | ) 16 | 17 | const testFileURI = "file:///Users/octref/Code/css-test/test.sql" 18 | 19 | type TestContext struct { 20 | h jsonrpc2.Handler 21 | conn *jsonrpc2.Conn 22 | connServer *jsonrpc2.Conn 23 | server *Server 24 | ctx context.Context 25 | } 26 | 27 | func newTestContext() *TestContext { 28 | server := NewServer() 29 | handler := jsonrpc2.HandlerWithError(server.Handle) 30 | ctx := context.Background() 31 | return &TestContext{ 32 | h: handler, 33 | ctx: ctx, 34 | server: server, 35 | } 36 | } 37 | 38 | func (tx *TestContext) setup(t *testing.T) { 39 | t.Helper() 40 | tx.initServer(t) 41 | } 42 | 43 | func (tx *TestContext) tearDown() { 44 | if tx.conn != nil { 45 | if err := tx.conn.Close(); err != nil { 46 | log.Fatal("conn.Close:", err) 47 | } 48 | } 49 | 50 | if tx.connServer != nil { 51 | if err := tx.connServer.Close(); err != nil { 52 | if !errors.Is(err, jsonrpc2.ErrClosed) { 53 | log.Fatal("connServer.Close:", err) 54 | } 55 | } 56 | } 57 | } 58 | 59 | 
func (tx *TestContext) initServer(t *testing.T) { 60 | t.Helper() 61 | 62 | // Prepare the server and client connection. 63 | client, server := net.Pipe() 64 | tx.connServer = jsonrpc2.NewConn(tx.ctx, jsonrpc2.NewBufferedStream(server, jsonrpc2.VSCodeObjectCodec{}), tx.h) 65 | tx.conn = jsonrpc2.NewConn(tx.ctx, jsonrpc2.NewBufferedStream(client, jsonrpc2.VSCodeObjectCodec{}), tx.h) 66 | 67 | // Initialize Language Server 68 | params := lsp.InitializeParams{ 69 | InitializationOptions: lsp.InitializeOptions{}, 70 | } 71 | if err := tx.conn.Call(tx.ctx, "initialize", params, nil); err != nil { 72 | t.Fatal("conn.Call initialize:", err) 73 | } 74 | } 75 | 76 | func (tx *TestContext) addWorkspaceConfig(t *testing.T, cfg *config.Config) { 77 | didChangeConfigurationParams := lsp.DidChangeConfigurationParams{ 78 | Settings: struct { 79 | SQLS *config.Config "json:\"sqls\"" 80 | }{ 81 | SQLS: cfg, 82 | }, 83 | } 84 | if err := tx.conn.Call(tx.ctx, "workspace/didChangeConfiguration", didChangeConfigurationParams, nil); err != nil { 85 | t.Fatal("conn.Call workspace/didChangeConfiguration:", err) 86 | } 87 | } 88 | 89 | func (tx *TestContext) textDocumentDidOpen(t *testing.T, uri, input string) { 90 | didOpenParams := lsp.DidOpenTextDocumentParams{ 91 | TextDocument: lsp.TextDocumentItem{ 92 | URI: uri, 93 | LanguageID: "sql", 94 | Version: 0, 95 | Text: input, 96 | }, 97 | } 98 | if err := tx.conn.Call(tx.ctx, "textDocument/didOpen", didOpenParams, nil); err != nil { 99 | t.Fatal("conn.Call textDocument/didOpen:", err) 100 | } 101 | tx.testFile(t, didOpenParams.TextDocument.URI, didOpenParams.TextDocument.Text) 102 | } 103 | 104 | func TestInitialized(t *testing.T) { 105 | tx := newTestContext() 106 | tx.setup(t) 107 | defer tx.tearDown() 108 | 109 | want := lsp.InitializeResult{ 110 | Capabilities: lsp.ServerCapabilities{ 111 | TextDocumentSync: lsp.TDSKFull, 112 | HoverProvider: true, 113 | CompletionProvider: &lsp.CompletionOptions{ 114 | TriggerCharacters: []string{"(", "."}, 115 | }, 116 | SignatureHelpProvider: &lsp.SignatureHelpOptions{ 117 | TriggerCharacters: []string{"(", ","}, 118 | RetriggerCharacters: []string{"(", ","}, 119 | WorkDoneProgressOptions: lsp.WorkDoneProgressOptions{ 120 | WorkDoneProgress: false, 121 | }, 122 | }, 123 | CodeActionProvider: true, 124 | DefinitionProvider: true, 125 | DocumentFormattingProvider: true, 126 | DocumentRangeFormattingProvider: true, 127 | RenameProvider: true, 128 | }, 129 | } 130 | var got lsp.InitializeResult 131 | params := lsp.InitializeParams{ 132 | InitializationOptions: lsp.InitializeOptions{}, 133 | } 134 | if err := tx.conn.Call(tx.ctx, "initialize", params, &got); err != nil { 135 | t.Fatal("conn.Call initialize:", err) 136 | } 137 | if !reflect.DeepEqual(want, got) { 138 | t.Errorf("not match \n%+v\n%+v", want, got) 139 | } 140 | } 141 | 142 | func TestFileWatch(t *testing.T) { 143 | tx := newTestContext() 144 | tx.setup(t) 145 | defer tx.tearDown() 146 | 147 | uri := "file:///Users/octref/Code/css-test/test.sql" 148 | openText := "SELECT * FROM todo ORDER BY id ASC" 149 | changeText := "SELECT * FROM todo ORDER BY name ASC" 150 | 151 | didOpenParams := lsp.DidOpenTextDocumentParams{ 152 | TextDocument: lsp.TextDocumentItem{ 153 | URI: uri, 154 | LanguageID: "sql", 155 | Version: 0, 156 | Text: openText, 157 | }, 158 | } 159 | if err := tx.conn.Call(tx.ctx, "textDocument/didOpen", didOpenParams, nil); err != nil { 160 | t.Fatal("conn.Call textDocument/didOpen:", err) 161 | } 162 | tx.testFile(t, didOpenParams.TextDocument.URI, didOpenParams.TextDocument.Text) 163 | 164 | didChangeParams := lsp.DidChangeTextDocumentParams{ 165 | TextDocument: lsp.VersionedTextDocumentIdentifier{ 166 | URI: uri, 167 | Version: 1, 168 | }, 169 | ContentChanges: []lsp.TextDocumentContentChangeEvent{ 170 | { 171 | Range: lsp.Range{ 172 | Start: lsp.Position{ 173 | Line: 1, 174 | Character: 1, 175 | }, 176 | End: lsp.Position{ 177 | Line: 1, 178 | Character: 1, 179 | }, 180 | }, 181 | RangeLength: 1, 182 | Text: changeText, 183 | }, 184 | }, 185 | } 186 | if err := tx.conn.Call(tx.ctx, "textDocument/didChange", didChangeParams, nil); err != nil { 187 | t.Fatal("conn.Call textDocument/didChange:", err) 188 | } 189 | tx.testFile(t, didChangeParams.TextDocument.URI, didChangeParams.ContentChanges[0].Text) 190 | 191 | didSaveParams := lsp.DidSaveTextDocumentParams{ 192 | Text: openText, 193 | TextDocument: lsp.TextDocumentIdentifier{URI: uri}, 194 | } 195 | if err := tx.conn.Call(tx.ctx, "textDocument/didSave", didSaveParams, nil); err != nil { 196 | t.Fatal("conn.Call textDocument/didSave:", err) 197 | } 198 | tx.testFile(t, didSaveParams.TextDocument.URI, didSaveParams.Text) 199 | 200 | didCloseParams := lsp.DidCloseTextDocumentParams{ 201 | TextDocument: lsp.TextDocumentIdentifier{URI: uri}, 202 | } 203 | if err := tx.conn.Call(tx.ctx, "textDocument/didClose", didCloseParams, nil); err != nil { 204 | t.Fatal("conn.Call textDocument/didClose:", err) 205 | } 206 | _, ok := tx.server.files[didCloseParams.TextDocument.URI] 207 | if ok { 208 | t.Errorf("found opened file. URI:%s", didCloseParams.TextDocument.URI) 209 | } 210 | } 211 | 212 | // testFile asserts that the server holds an open file for uri whose text equals text. 213 | func (tx *TestContext) testFile(t *testing.T, uri, text string) { 214 | t.Helper() 215 | f, ok := tx.server.files[uri] 216 | if !ok { 217 | // Stop here: without the file there is nothing to compare (f may be a nil pointer). 218 | t.Fatalf("not found opened file. URI:%s", uri) 219 | } 220 | if f.Text != text { 221 | t.Errorf("not match %s. got: %s", text, f.Text) 222 | } 223 | } 224 | 
-------------------------------------------------------------------------------- /internal/handler/rename.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/sourcegraph/jsonrpc2" 9 | "github.com/sqls-server/sqls/ast" 10 | "github.com/sqls-server/sqls/ast/astutil" 11 | "github.com/sqls-server/sqls/internal/lsp" 12 | "github.com/sqls-server/sqls/parser" 13 | "github.com/sqls-server/sqls/parser/parseutil" 14 | "github.com/sqls-server/sqls/token" 15 | ) 16 | 17 | func (s *Server) handleTextDocumentRename(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) (result interface{}, err error) { 18 | if req.Params == nil { 19 | return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} 20 | } 21 | 22 | var params lsp.RenameParams 23 | if err := json.Unmarshal(*req.Params, &params); err != nil { 24 | return nil, err 25 | } 26 | 27 | f, ok := s.files[params.TextDocument.URI] 28 | if !ok { 29 | return nil, fmt.Errorf("document not found: %s", params.TextDocument.URI) 30 | } 31 | 32 | res, err := rename(f.Text, params) 33 | if err != nil { 34 | return nil, err 35 | } 36 | return res, nil 37 | } 38 | 39 | func rename(text string, params lsp.RenameParams) (*lsp.WorkspaceEdit, error) { 40 | parsed, err := parser.Parse(text) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | pos := token.Pos{ 46 | Line: params.Position.Line, 47 | Col: params.Position.Character, 48 | } 49 | 50 | // Get the identifier on focus 51 | nodeWalker := parseutil.NewNodeWalker(parsed, pos) 52 | m := astutil.NodeMatcher{ 53 | NodeTypes: []ast.NodeType{ast.TypeIdentifier}, 54 | } 55 | currentVariable := nodeWalker.CurNodeBottomMatched(m) 56 | if currentVariable == nil { 57 | return nil, nil 58 | } 59 | 60 | // Get all identifiers in the statement 61 | idents, err := parseutil.ExtractIdenfiers(parsed, pos) 62 | if 
err != nil { 63 | return nil, err 64 | } 65 | 66 | // Extract only those with matching names 67 | renameTarget := []ast.Node{} 68 | for _, ident := range idents { 69 | if ident.String() == currentVariable.String() { 70 | renameTarget = append(renameTarget, ident) 71 | } 72 | } 73 | if len(renameTarget) == 0 { 74 | return nil, nil 75 | } 76 | 77 | edits := make([]lsp.TextEdit, len(renameTarget)) 78 | for i, target := range renameTarget { 79 | edit := lsp.TextEdit{ 80 | Range: lsp.Range{ 81 | Start: lsp.Position{ 82 | Line: target.Pos().Line, 83 | Character: target.Pos().Col, 84 | }, 85 | End: lsp.Position{ 86 | Line: target.End().Line, 87 | Character: target.End().Col, 88 | }, 89 | }, 90 | NewText: params.NewName, 91 | } 92 | edits[i] = edit 93 | } 94 | 95 | res := &lsp.WorkspaceEdit{ 96 | DocumentChanges: []lsp.TextDocumentEdit{ 97 | { 98 | TextDocument: lsp.OptionalVersionedTextDocumentIdentifier{ 99 | Version: 0, 100 | TextDocumentIdentifier: lsp.TextDocumentIdentifier{ 101 | URI: params.TextDocument.URI, 102 | }, 103 | }, 104 | Edits: edits, 105 | }, 106 | }, 107 | } 108 | 109 | return res, nil 110 | } 111 | -------------------------------------------------------------------------------- /internal/handler/rename_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/google/go-cmp/cmp" 7 | "github.com/sqls-server/sqls/internal/config" 8 | "github.com/sqls-server/sqls/internal/database" 9 | "github.com/sqls-server/sqls/internal/lsp" 10 | ) 11 | 12 | var renameTestCases = []struct { 13 | name string 14 | input string 15 | newName string 16 | output lsp.WorkspaceEdit 17 | pos lsp.Position 18 | }{ 19 | { 20 | name: "subquery", 21 | input: "SELECT it.ID, it.Name FROM (SELECT ci.ID, ci.Name, ci.CountryCode, ci.District, ci.Population FROM city AS ci) as it", 22 | newName: "ct", 23 | output: lsp.WorkspaceEdit{ 24 | DocumentChanges: []lsp.TextDocumentEdit{ 25 | { 26 
| TextDocument: lsp.OptionalVersionedTextDocumentIdentifier{ 27 | Version: 0, 28 | TextDocumentIdentifier: lsp.TextDocumentIdentifier{ 29 | URI: "file:///Users/octref/Code/css-test/test.sql", 30 | }, 31 | }, 32 | Edits: []lsp.TextEdit{ 33 | { 34 | Range: lsp.Range{ 35 | Start: lsp.Position{ 36 | Line: 0, 37 | Character: 7, 38 | }, 39 | End: lsp.Position{ 40 | Line: 0, 41 | Character: 9, 42 | }, 43 | }, 44 | NewText: "ct", 45 | }, 46 | { 47 | Range: lsp.Range{ 48 | Start: lsp.Position{ 49 | Line: 0, 50 | Character: 14, 51 | }, 52 | End: lsp.Position{ 53 | Line: 0, 54 | Character: 16, 55 | }, 56 | }, 57 | NewText: "ct", 58 | }, 59 | { 60 | Range: lsp.Range{ 61 | Start: lsp.Position{ 62 | Line: 0, 63 | Character: 114, 64 | }, 65 | End: lsp.Position{ 66 | Line: 0, 67 | Character: 116, 68 | }, 69 | }, 70 | NewText: "ct", 71 | }, 72 | }, 73 | }, 74 | }, 75 | }, 76 | pos: lsp.Position{ 77 | Line: 0, 78 | Character: 8, 79 | }, 80 | }, 81 | { 82 | name: "ok", 83 | input: "SELECT ci.ID, ci.Name FROM city as ci", 84 | newName: "ct", 85 | output: lsp.WorkspaceEdit{ 86 | DocumentChanges: []lsp.TextDocumentEdit{ 87 | { 88 | TextDocument: lsp.OptionalVersionedTextDocumentIdentifier{ 89 | Version: 0, 90 | TextDocumentIdentifier: lsp.TextDocumentIdentifier{ 91 | URI: "file:///Users/octref/Code/css-test/test.sql", 92 | }, 93 | }, 94 | Edits: []lsp.TextEdit{ 95 | { 96 | Range: lsp.Range{ 97 | Start: lsp.Position{ 98 | Line: 0, 99 | Character: 7, 100 | }, 101 | End: lsp.Position{ 102 | Line: 0, 103 | Character: 9, 104 | }, 105 | }, 106 | NewText: "ct", 107 | }, 108 | { 109 | Range: lsp.Range{ 110 | Start: lsp.Position{ 111 | Line: 0, 112 | Character: 14, 113 | }, 114 | End: lsp.Position{ 115 | Line: 0, 116 | Character: 16, 117 | }, 118 | }, 119 | NewText: "ct", 120 | }, 121 | { 122 | Range: lsp.Range{ 123 | Start: lsp.Position{ 124 | Line: 0, 125 | Character: 35, 126 | }, 127 | End: lsp.Position{ 128 | Line: 0, 129 | Character: 37, 130 | }, 131 | }, 132 | NewText: "ct", 133 | }, 134 | 
}, 135 | }, 136 | }, 137 | }, 138 | pos: lsp.Position{ 139 | Line: 0, 140 | Character: 8, 141 | }, 142 | }, 143 | } 144 | 145 | func TestRenameMain(t *testing.T) { 146 | tx := newTestContext() 147 | tx.setup(t) 148 | defer tx.tearDown() 149 | 150 | cfg := &config.Config{ 151 | Connections: []*database.DBConfig{ 152 | {Driver: "mock"}, 153 | }, 154 | } 155 | tx.addWorkspaceConfig(t, cfg) 156 | 157 | for _, tt := range renameTestCases { 158 | t.Run(tt.name, func(t *testing.T) { 159 | tx.textDocumentDidOpen(t, testFileURI, tt.input) 160 | 161 | params := lsp.RenameParams{ 162 | TextDocument: lsp.TextDocumentIdentifier{ 163 | URI: testFileURI, 164 | }, 165 | Position: tt.pos, 166 | NewName: tt.newName, 167 | } 168 | var got lsp.WorkspaceEdit 169 | err := tx.conn.Call(tx.ctx, "textDocument/rename", params, &got) 170 | if err != nil { 171 | t.Errorf("conn.Call textDocument/rename: %+v", err) 172 | return 173 | } 174 | 175 | if diff := cmp.Diff(tt.output, got); diff != "" { 176 | t.Errorf("unmatch rename edits (- want, + got):\n%s", diff) 177 | } 178 | }) 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /internal/handler/signature_help.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/sourcegraph/jsonrpc2" 9 | "github.com/sqls-server/sqls/internal/database" 10 | "github.com/sqls-server/sqls/internal/lsp" 11 | "github.com/sqls-server/sqls/parser" 12 | "github.com/sqls-server/sqls/parser/parseutil" 13 | "github.com/sqls-server/sqls/token" 14 | ) 15 | 16 | func (s *Server) handleTextDocumentSignatureHelp(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) (result interface{}, err error) { 17 | if req.Params == nil { 18 | return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams} 19 | } 20 | 21 | var params lsp.SignatureHelpParams 22 | if err := json.Unmarshal(*req.Params, 
&params); err != nil { 23 | return nil, err 24 | } 25 | 26 | f, ok := s.files[params.TextDocument.URI] 27 | if !ok { 28 | return nil, fmt.Errorf("document not found: %s", params.TextDocument.URI) 29 | } 30 | 31 | res, err := SignatureHelp(f.Text, params, s.worker.Cache()) 32 | if err != nil { 33 | return nil, err 34 | } 35 | return res, nil 36 | } 37 | 38 | func SignatureHelp(text string, params lsp.SignatureHelpParams, dbCache *database.DBCache) (*lsp.SignatureHelp, error) { 39 | if dbCache == nil { 40 | return nil, nil 41 | } 42 | 43 | parsed, err := parser.Parse(text) 44 | if err != nil { 45 | return nil, err 46 | } 47 | 48 | pos := token.Pos{ 49 | Line: params.Position.Line, 50 | Col: params.Position.Character, 51 | } 52 | nodeWalker := parseutil.NewNodeWalker(parsed, pos) 53 | types := getSignatureHelpTypes(nodeWalker) 54 | 55 | switch { 56 | case signatureHelpIs(types, SignatureHelpTypeInsertValue): 57 | insert, err := parseutil.ExtractInsert(parsed, pos) 58 | if err != nil { 59 | return nil, err 60 | } 61 | if !insert.Enable() { 62 | return nil, err 63 | } 64 | 65 | table := insert.GetTable() 66 | cols := insert.GetColumns() 67 | paramIdx := insert.GetValues().GetIndex(pos) 68 | tableName := table.Name 69 | 70 | params := []lsp.ParameterInformation{} 71 | for _, col := range cols.GetIdentifiers() { 72 | colName := col.String() 73 | colDoc := "" 74 | colDesc, ok := dbCache.Column(tableName, colName) 75 | if ok { 76 | colDoc = colDesc.OnelineDesc() 77 | } 78 | p := lsp.ParameterInformation{ 79 | Label: colName, 80 | Documentation: colDoc, 81 | } 82 | params = append(params, p) 83 | } 84 | 85 | signatureLabel := fmt.Sprintf("%s (%s)", tableName, cols.String()) 86 | sh := &lsp.SignatureHelp{ 87 | Signatures: []lsp.SignatureInformation{ 88 | { 89 | Label: signatureLabel, 90 | Documentation: fmt.Sprintf("%s table columns", tableName), 91 | Parameters: params, 92 | }, 93 | }, 94 | ActiveSignature: 0.0, 95 | ActiveParameter: float64(paramIdx), 96 | } 97 | return sh, nil
98 | default: 99 | // pass 100 | return nil, nil 101 | } 102 | } 103 | 104 | type signatureHelpType int 105 | 106 | const ( 107 | _ signatureHelpType = iota 108 | SignatureHelpTypeInsertValue 109 | SignatureHelpTypeUnknown = 99 110 | ) 111 | 112 | func (sht signatureHelpType) String() string { 113 | switch sht { 114 | case SignatureHelpTypeInsertValue: 115 | return "InsertValue" 116 | default: 117 | return "" 118 | } 119 | } 120 | 121 | func getSignatureHelpTypes(nw *parseutil.NodeWalker) []signatureHelpType { 122 | syntaxPos := parseutil.CheckSyntaxPosition(nw) 123 | types := []signatureHelpType{} 124 | switch { 125 | case syntaxPos == parseutil.InsertValue: 126 | types = []signatureHelpType{ 127 | SignatureHelpTypeInsertValue, 128 | } 129 | default: 130 | // pass 131 | } 132 | return types 133 | } 134 | 135 | func signatureHelpIs(types []signatureHelpType, expect signatureHelpType) bool { 136 | for _, t := range types { 137 | if t == expect { 138 | return true 139 | } 140 | } 141 | return false 142 | } 143 | -------------------------------------------------------------------------------- /internal/handler/signature_help_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/google/go-cmp/cmp" 8 | "github.com/sqls-server/sqls/internal/config" 9 | "github.com/sqls-server/sqls/internal/database" 10 | "github.com/sqls-server/sqls/internal/lsp" 11 | ) 12 | 13 | type signatureHelpTestCase struct { 14 | name string 15 | input string 16 | line int 17 | col int 18 | want lsp.SignatureHelp 19 | } 20 | 21 | var signatureHelpTestCases = []signatureHelpTestCase{ 22 | // single record 23 | // input is "insert into city (ID, Name, CountryCode) VALUES (123, NULL, '2020')" 24 | genSingleRecordInsertTest(50, 0), 25 | genSingleRecordInsertTest(52, 0), 26 | genSingleRecordInsertTest(53, 1), 27 | genSingleRecordInsertTest(59, 1), 28 | genSingleRecordInsertTest(60, 2), 
29 | genSingleRecordInsertTest(67, 2), 30 | 31 | // multi record 32 | // input is "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020'), (456, 'bbb', '2021')" 33 | genMultiRecordInsertTest(50, 0), 34 | genMultiRecordInsertTest(52, 0), 35 | genMultiRecordInsertTest(53, 1), 36 | genMultiRecordInsertTest(59, 1), 37 | genMultiRecordInsertTest(60, 2), 38 | genMultiRecordInsertTest(67, 2), 39 | 40 | genMultiRecordInsertTest(72, 0), 41 | genMultiRecordInsertTest(74, 0), 42 | genMultiRecordInsertTest(76, 1), 43 | genMultiRecordInsertTest(81, 1), 44 | genMultiRecordInsertTest(83, 2), 45 | genMultiRecordInsertTest(89, 2), 46 | } 47 | 48 | func genSingleRecordInsertTest(col int, wantActiveParameter int) signatureHelpTestCase { 49 | return signatureHelpTestCase{ 50 | name: fmt.Sprintf("single record %d-%d", col, wantActiveParameter), 51 | input: "insert into city (ID, Name, CountryCode) VALUES (123, NULL, '2020')", 52 | line: 0, 53 | col: col, 54 | want: lsp.SignatureHelp{ 55 | Signatures: []lsp.SignatureInformation{ 56 | { 57 | Label: "city (ID, Name, CountryCode)", 58 | Documentation: "city table columns", 59 | Parameters: []lsp.ParameterInformation{ 60 | { 61 | Label: "ID", 62 | Documentation: "`int(11)` PRI auto_increment", 63 | }, 64 | { 65 | Label: "Name", 66 | Documentation: "`char(35)`", 67 | }, 68 | { 69 | Label: "CountryCode", 70 | Documentation: "`char(3)` MUL", 71 | }, 72 | }, 73 | }, 74 | }, 75 | ActiveSignature: 0.0, 76 | ActiveParameter: float64(wantActiveParameter), 77 | }, 78 | } 79 | } 80 | 81 | func genMultiRecordInsertTest(col int, wantActiveParameter int) signatureHelpTestCase { 82 | return signatureHelpTestCase{ 83 | name: fmt.Sprintf("multi record %d-%d", col, wantActiveParameter), 84 | input: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020'), (456, 'bbb', '2021')", 85 | line: 0, 86 | col: col, 87 | want: lsp.SignatureHelp{ 88 | Signatures: []lsp.SignatureInformation{ 89 | { 90 | Label: "city (ID, Name, CountryCode)", 
91 | Documentation: "city table columns", 92 | Parameters: []lsp.ParameterInformation{ 93 | { 94 | Label: "ID", 95 | Documentation: "`int(11)` PRI auto_increment", 96 | }, 97 | { 98 | Label: "Name", 99 | Documentation: "`char(35)`", 100 | }, 101 | { 102 | Label: "CountryCode", 103 | Documentation: "`char(3)` MUL", 104 | }, 105 | }, 106 | }, 107 | }, 108 | ActiveSignature: 0.0, 109 | ActiveParameter: float64(wantActiveParameter), 110 | }, 111 | } 112 | } 113 | 114 | func TestSignatureHelpMain(t *testing.T) { 115 | tx := newTestContext() 116 | tx.initServer(t) 117 | defer tx.tearDown() 118 | 119 | cfg := &config.Config{ 120 | Connections: []*database.DBConfig{ 121 | {Driver: "mock"}, 122 | }, 123 | } 124 | tx.addWorkspaceConfig(t, cfg) 125 | 126 | for _, tt := range signatureHelpTestCases { 127 | t.Run(tt.name, func(t *testing.T) { 128 | tx.textDocumentDidOpen(t, testFileURI, tt.input) 129 | 130 | params := lsp.SignatureHelpParams{ 131 | TextDocumentPositionParams: lsp.TextDocumentPositionParams{ 132 | TextDocument: lsp.TextDocumentIdentifier{ 133 | URI: testFileURI, 134 | }, 135 | Position: lsp.Position{ 136 | Line: tt.line, 137 | Character: tt.col, 138 | }, 139 | }, 140 | } 141 | var got lsp.SignatureHelp 142 | if err := tx.conn.Call(tx.ctx, "textDocument/signatureHelp", params, &got); err != nil { 143 | t.Fatal("conn.Call textDocument/signatureHelp:", err) 144 | } 145 | if diff := cmp.Diff(tt.want, got); diff != "" { 146 | t.Errorf("unmatch (- want, + got):\n%s", diff) 147 | } 148 | }) 149 | } 150 | } 151 | 152 | func TestSignatureHelpNoneDBConnection(t *testing.T) { 153 | tx := newTestContext() 154 | tx.initServer(t) 155 | defer tx.tearDown() 156 | 157 | cfg := &config.Config{ 158 | Connections: []*database.DBConfig{}, 159 | } 160 | tx.addWorkspaceConfig(t, cfg) 161 | 162 | uri := "file:///Users/octref/Code/css-test/test.sql" 163 | for _, tt := range signatureHelpTestCases { 164 | t.Run(tt.name, func(t *testing.T) { 165 | tx.textDocumentDidOpen(t, testFileURI, 
tt.input) 166 | 167 | params := lsp.SignatureHelpParams{ 168 | TextDocumentPositionParams: lsp.TextDocumentPositionParams{ 169 | TextDocument: lsp.TextDocumentIdentifier{ 170 | URI: uri, 171 | }, 172 | Position: lsp.Position{ 173 | Line: tt.line, 174 | Character: tt.col, 175 | }, 176 | }, 177 | } 178 | // Without a DB connection, it is not possible to provide functions using the DB connection, so just make sure that no errors occur. 179 | var got lsp.SignatureHelp 180 | if err := tx.conn.Call(tx.ctx, "textDocument/signatureHelp", params, &got); err != nil { 181 | t.Fatal("conn.Call textDocument/signatureHelp:", err) 182 | } 183 | }) 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_arithmetic_expression.golden.sql: -------------------------------------------------------------------------------- 1 | select 2 | 1 + 1, 3 | 2 - 1, 4 | 3 * 2, 5 | 8 / 2, 6 | 1 + 1 * 3, 7 | 3 + 8 / 7, 8 | 1 + 1 * 3, 9 | 312 + 8 / 7, 10 | 4 % 3, 11 | 7 ^ 5 12 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_arithmetic_expression.input.sql: -------------------------------------------------------------------------------- 1 | select 1 + 1, 2 - 1, 3 * 2, 8 / 2, 2 | 1 + 1 * 3, 3 + 8 / 7, 3 | 1+1*3, 312+8/7, 4 | 4%3, 7^5 5 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_basic.golden.sql: -------------------------------------------------------------------------------- 1 | select 2 | a, 3 | b as bb, 4 | c 5 | from 6 | table 7 | join ( 8 | select 9 | a * 2 as a 10 | from 11 | new_table 12 | ) other 13 | on table.a = other.a 14 | where 15 | c is true 16 | and b between 3 17 | and 4 18 | or d is 'blue' 19 | limit 10 20 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_basic.input.sql: 
-------------------------------------------------------------------------------- 1 | select a, b as bb,c from table 2 | join (select a * 2 as a from new_table) other 3 | on table.a = other.a 4 | where c is true 5 | and b between 3 and 4 6 | or d is 'blue' 7 | limit 10 8 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_group_by.golden.sql: -------------------------------------------------------------------------------- 1 | select 2 | a, 3 | b, 4 | c, 5 | sum(x) as sum_x, 6 | count(y) as cnt_y 7 | from 8 | table 9 | group by 10 | a, 11 | b, 12 | c 13 | having 14 | sum(x) > 1 15 | and count(y) > 5 16 | order by 17 | 3, 18 | 2, 19 | 1 20 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_group_by.input.sql: -------------------------------------------------------------------------------- 1 | select a, b, c, sum(x) as sum_x, count(y) as cnt_y 2 | from table 3 | group by a,b,c 4 | having sum(x) > 1 5 | and count(y) > 5 6 | order by 3,2,1 7 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_group_by_subquery.golden.sql: -------------------------------------------------------------------------------- 1 | select 2 | *, 3 | sum_b + 2 as mod_sum 4 | from 5 | ( 6 | select 7 | a, 8 | sum(b) as sum_b 9 | from 10 | table 11 | group by 12 | a, 13 | z 14 | ) 15 | order by 16 | 1, 17 | 2 18 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_group_by_subquery.input.sql: -------------------------------------------------------------------------------- 1 | select *, sum_b + 2 as mod_sum 2 | from ( 3 | select a, sum(b) as sum_b 4 | from table 5 | group by a,z) 6 | order by 1,2 7 | -------------------------------------------------------------------------------- 
/internal/handler/testdata/format/select_join.golden.sql: -------------------------------------------------------------------------------- 1 | select 2 | * 3 | from 4 | a 5 | join b 6 | on a.one = b.one 7 | left join c 8 | on c.two = a.two 9 | and c.three = a.three 10 | right outer join d 11 | on d.three = a.three 12 | cross join e 13 | on e.four = a.four 14 | join f using ( 15 | one, 16 | two, 17 | three 18 | ) 19 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_join.input.sql: -------------------------------------------------------------------------------- 1 | select * from a 2 | join b on a.one = b.one 3 | left join c on c.two = a.two and c.three = a.three 4 | right outer join d on d.three = a.three 5 | cross join e on e.four = a.four 6 | join f using (one, two, three) 7 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_statement.golden.sql: -------------------------------------------------------------------------------- 1 | select 2 | a, 3 | case 4 | when a = 0 then 1 5 | when bb = 1 then 1 6 | when c = 2 then 2 7 | else 0 8 | end as d, 9 | extra_col 10 | from 11 | table 12 | where 13 | c is true 14 | and b between 3 15 | and 4 16 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_statement.input.sql: -------------------------------------------------------------------------------- 1 | select a, 2 | case when a = 0 3 | then 1 4 | when bb = 1 then 1 5 | when c = 2 then 2 6 | else 0 end as d, 7 | extra_col 8 | from table 9 | where c is true 10 | and b between 3 and 4 11 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_statement_with_between.golden.sql: -------------------------------------------------------------------------------- 1 | select 2 | a, 3 | case 4 | when a = 0 
then 1 5 | when bb = 1 then 1 6 | when c = 2 then 2 7 | when d between 3 8 | and 5 then 3 9 | else 0 10 | end as d, 11 | extra_col 12 | from 13 | table 14 | where 15 | c is true 16 | and b between 3 17 | and 4 18 | -------------------------------------------------------------------------------- /internal/handler/testdata/format/select_statement_with_between.input.sql: -------------------------------------------------------------------------------- 1 | select a, 2 | case when a = 0 3 | then 1 4 | when bb = 1 then 1 5 | when c = 2 then 2 6 | when d between 3 and 5 then 3 7 | else 0 end as d, 8 | extra_col 9 | from table 10 | where c is true 11 | and b between 3 and 4 12 | -------------------------------------------------------------------------------- /internal/handler/testdata/format_option_space2/select_basic_space2.golden.sql: -------------------------------------------------------------------------------- 1 | select 2 | a, 3 | b as bb, 4 | c 5 | from 6 | table 7 | join ( 8 | select 9 | a * 2 as a 10 | from 11 | new_table 12 | ) other 13 | on table.a = other.a 14 | where 15 | c is true 16 | and b between 3 17 | and 4 18 | or d is 'blue' 19 | limit 10 20 | -------------------------------------------------------------------------------- /internal/handler/testdata/format_option_space2/select_basic_space2.input.sql: -------------------------------------------------------------------------------- 1 | select a, b as bb,c from table 2 | join (select a * 2 as a from new_table) other 3 | on table.a = other.a 4 | where c is true 5 | and b between 3 and 4 6 | or d is 'blue' 7 | limit 10 8 | -------------------------------------------------------------------------------- /internal/handler/testdata/format_option_space4/select_basic_space4.golden.sql: -------------------------------------------------------------------------------- 1 | select 2 | a, 3 | b as bb, 4 | c 5 | from 6 | table 7 | join ( 8 | select 9 | a * 2 as a 10 | from 11 | new_table 12 | ) other 13 | on table.a = 
other.a 14 | where 15 | c is true 16 | and b between 3 17 | and 4 18 | or d is 'blue' 19 | limit 10 20 | -------------------------------------------------------------------------------- /internal/handler/testdata/format_option_space4/select_basic_space4.input.sql: -------------------------------------------------------------------------------- 1 | select a, b as bb,c from table 2 | join (select a * 2 as a from new_table) other 3 | on table.a = other.a 4 | where c is true 5 | and b between 3 and 4 6 | or d is 'blue' 7 | limit 10 8 | -------------------------------------------------------------------------------- /internal/handler/testdata/upper_case/select_basic.golden.sql: -------------------------------------------------------------------------------- 1 | SELECT 2 | a, 3 | b AS bb, 4 | c 5 | FROM 6 | tbl 7 | JOIN ( 8 | SELECT 9 | a * 2 AS a 10 | FROM 11 | new_table 12 | ) other 13 | ON tbl.a = other.a 14 | WHERE 15 | c IS TRUE 16 | AND b BETWEEN 3 17 | AND 4 18 | OR d IS 'blue' 19 | LIMIT 10 20 | -------------------------------------------------------------------------------- /internal/handler/testdata/upper_case/select_basic.input.sql: -------------------------------------------------------------------------------- 1 | select a, b as bb,c from tbl 2 | join (select a * 2 as a from new_table) other 3 | on tbl.a = other.a 4 | where c is true 5 | and b between 3 and 4 6 | or d is 'blue' 7 | limit 10 8 | -------------------------------------------------------------------------------- /internal/lsp/client.go: -------------------------------------------------------------------------------- 1 | package lsp 2 | 3 | import ( 4 | "context" 5 | "log" 6 | 7 | "github.com/sourcegraph/jsonrpc2" 8 | ) 9 | 10 | type MessageDisplayer interface { 11 | ShowLog(context.Context, string) error 12 | ShowInfo(context.Context, string) error 13 | ShowWarning(context.Context, string) error 14 | ShowError(context.Context, string) error 15 | } 16 | 17 | type Messenger struct { 18 | conn 
*jsonrpc2.Conn 19 | } 20 | 21 | func NewMessenger(conn *jsonrpc2.Conn) MessageDisplayer { 22 | return &Messenger{ 23 | conn: conn, 24 | } 25 | } 26 | 27 | func (m *Messenger) ShowLog(ctx context.Context, message string) error { 28 | log.Println("Send Message:", message) 29 | params := &ShowMessageParams{ 30 | Type: Log, 31 | Message: message, 32 | } 33 | return m.conn.Notify(ctx, "window/showMessage", params) 34 | } 35 | 36 | func (m *Messenger) ShowInfo(ctx context.Context, message string) error { 37 | log.Println("Send Message:", message) 38 | params := &ShowMessageParams{ 39 | Type: Info, 40 | Message: message, 41 | } 42 | return m.conn.Notify(ctx, "window/showMessage", params) 43 | } 44 | 45 | func (m *Messenger) ShowWarning(ctx context.Context, message string) error { 46 | log.Println("Send Message:", message) 47 | params := &ShowMessageParams{ 48 | Type: Warning, 49 | Message: message, 50 | } 51 | return m.conn.Notify(ctx, "window/showMessage", params) 52 | } 53 | 54 | func (m *Messenger) ShowError(ctx context.Context, message string) error { 55 | log.Println("Send Message:", message) 56 | params := &ShowMessageParams{ 57 | Type: Error, 58 | Message: message, 59 | } 60 | return m.conn.Notify(ctx, "window/showMessage", params) 61 | } 62 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "log" 9 | "os" 10 | "os/exec" 11 | "runtime" 12 | "strings" 13 | 14 | "github.com/sourcegraph/jsonrpc2" 15 | "github.com/urfave/cli/v2" 16 | 17 | "github.com/sqls-server/sqls/internal/config" 18 | "github.com/sqls-server/sqls/internal/handler" 19 | ) 20 | 21 | const name = "sqls" 22 | 23 | const version = "0.2.28" 24 | 25 | var revision = "HEAD" 26 | 27 | func main() { 28 | if err := realMain(); err != nil { 29 | fmt.Fprintln(os.Stderr, err.Error()) 30 | os.Exit(1) 31 | } 32 
| os.Exit(0) 33 | } 34 | 35 | func realMain() error { 36 | app := &cli.App{ 37 | Name: "sqls", 38 | Version: fmt.Sprintf("Version:%s, Revision:%s\n", version, revision), 39 | Usage: "An implementation of the Language Server Protocol for SQL.", 40 | Flags: []cli.Flag{ 41 | &cli.StringFlag{ 42 | Name: "log", 43 | Aliases: []string{"l"}, 44 | Usage: "Also log to this file. (in addition to stderr)", 45 | }, 46 | &cli.StringFlag{ 47 | Name: "config", 48 | Aliases: []string{"c"}, 49 | Usage: "Specifies an alternative per-user configuration file. If a configuration file is given on the command line, the workspace option (initializationOptions) will be ignored.", 50 | }, 51 | &cli.BoolFlag{ 52 | Name: "trace", 53 | Aliases: []string{"t"}, 54 | Usage: "Print all requests and responses.", 55 | }, 56 | }, 57 | Commands: cli.Commands{ 58 | { 59 | Name: "config", 60 | Aliases: []string{"c"}, 61 | Usage: "edit config", 62 | Action: func(c *cli.Context) error { 63 | editorEnv := os.Getenv("EDITOR") 64 | if editorEnv == "" { 65 | editorEnv = "vim" 66 | } 67 | return openEditor(editorEnv, config.YamlConfigPath) 68 | }, 69 | }, 70 | }, 71 | Action: func(c *cli.Context) error { 72 | return serve(c) 73 | }, 74 | } 75 | cli.VersionFlag = &cli.BoolFlag{ 76 | Name: "version", 77 | Aliases: []string{"v"}, 78 | Usage: "Print version.", 79 | } 80 | cli.HelpFlag = &cli.BoolFlag{ 81 | Name: "help", 82 | Aliases: []string{"h"}, 83 | Usage: "Print help.", 84 | } 85 | 86 | err := app.Run(os.Args) 87 | if err != nil { 88 | return err 89 | } 90 | 91 | return nil 92 | } 93 | 94 | func serve(c *cli.Context) error { 95 | logfile := c.String("log") 96 | configFile := c.String("config") 97 | trace := c.Bool("trace") 98 | 99 | // Initialize log writer 100 | var logWriter io.Writer 101 | if logfile != "" { 102 | f, err := os.OpenFile(logfile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0660) 103 | if err != nil { 104 | log.Fatal(err) 105 | } 106 | defer f.Close() 107 | logWriter = io.MultiWriter(os.Stderr, f) 108 
| } else { 109 | logWriter = io.MultiWriter(os.Stderr) 110 | } 111 | log.SetOutput(logWriter) 112 | 113 | // Initialize language server 114 | server := handler.NewServer() 115 | defer func() { 116 | if err := server.Stop(); err != nil { 117 | log.Println(err) 118 | } 119 | }() 120 | h := jsonrpc2.HandlerWithError(server.Handle) 121 | 122 | // Load specific config 123 | if configFile != "" { 124 | cfg, err := config.GetConfig(configFile) 125 | if err != nil { 126 | return fmt.Errorf("cannot read specified config, %w", err) 127 | } 128 | server.SpecificFileCfg = cfg 129 | } else { 130 | // Load default config 131 | cfg, err := config.GetDefaultConfig() 132 | if err != nil && !errors.Is(config.ErrNotFoundConfig, err) { 133 | return fmt.Errorf("cannot read default config, %w", err) 134 | } 135 | server.DefaultFileCfg = cfg 136 | } 137 | 138 | // Set connect option 139 | var connOpt []jsonrpc2.ConnOpt 140 | if trace { 141 | connOpt = append(connOpt, jsonrpc2.LogMessages(log.New(logWriter, "", 0))) 142 | } 143 | 144 | // Start language server 145 | log.Println("sqls: reading on stdin, writing on stdout") 146 | <-jsonrpc2.NewConn( 147 | context.Background(), 148 | jsonrpc2.NewBufferedStream(stdrwc{}, jsonrpc2.VSCodeObjectCodec{}), 149 | h, 150 | connOpt..., 151 | ).DisconnectNotify() 152 | log.Println("sqls: connections closed") 153 | 154 | return nil 155 | } 156 | 157 | type stdrwc struct{} 158 | 159 | func (stdrwc) Read(p []byte) (int, error) { 160 | return os.Stdin.Read(p) 161 | } 162 | 163 | func (stdrwc) Write(p []byte) (int, error) { 164 | return os.Stdout.Write(p) 165 | } 166 | 167 | func (stdrwc) Close() error { 168 | if err := os.Stdin.Close(); err != nil { 169 | return err 170 | } 171 | return os.Stdout.Close() 172 | } 173 | 174 | func openEditor(program string, args ...string) error { 175 | cmdargs := strings.Join(args, " ") 176 | command := program + " " + cmdargs 177 | 178 | var cmd *exec.Cmd 179 | if runtime.GOOS == "windows" { 180 | cmd = 
exec.Command("cmd", "/c", command) 181 | } else { 182 | cmd = exec.Command("sh", "-c", command) 183 | } 184 | cmd.Stdin = os.Stdin 185 | cmd.Stdout = os.Stdout 186 | cmd.Stderr = os.Stderr 187 | return cmd.Run() 188 | } 189 | -------------------------------------------------------------------------------- /parser/parseutil/extract.go: -------------------------------------------------------------------------------- 1 | package parseutil 2 | 3 | import ( 4 | "github.com/sqls-server/sqls/ast" 5 | "github.com/sqls-server/sqls/ast/astutil" 6 | "github.com/sqls-server/sqls/token" 7 | ) 8 | 9 | type ( 10 | prefixParseFn func(reader *astutil.NodeReader) []ast.Node 11 | infixParseFn func(reader *astutil.NodeReader) []ast.Node 12 | ) 13 | 14 | func parsePrefix(reader *astutil.NodeReader, matcher astutil.NodeMatcher, fn prefixParseFn) []ast.Node { 15 | var replaceNodes []ast.Node 16 | for reader.NextNode(false) { 17 | if reader.CurNodeIs(matcher) { 18 | replaceNodes = append(replaceNodes, fn(reader)...) 19 | } else if list, ok := reader.CurNode.(ast.TokenList); ok { 20 | newReader := astutil.NewNodeReader(list) 21 | replaceNodes = append(replaceNodes, parsePrefix(newReader, matcher, fn)...) 
22 | 		}
23 | 	}
24 | 	return replaceNodes
25 | }
26 |
27 | // ExtractSelectExpr searches parsed, descending into nested token lists, and
28 | // returns every node that directly follows a SELECT, ALL or DISTINCT keyword
29 | // when that node is an identifier list, identifier, member identifier,
30 | // operator, aliased node, parenthesis or function literal.
31 | func ExtractSelectExpr(parsed ast.TokenList) []ast.Node {
32 | 	prefixMatcher := astutil.NodeMatcher{
33 | 		ExpectKeyword: []string{
34 | 			"SELECT",
35 | 			"ALL",
36 | 			"DISTINCT",
37 | 		},
38 | 	}
39 | 	peekMatcher := astutil.NodeMatcher{
40 | 		NodeTypes: []ast.NodeType{
41 | 			ast.TypeIdentifierList,
42 | 			ast.TypeIdentifier,
43 | 			ast.TypeMemberIdentifier,
44 | 			ast.TypeOperator,
45 | 			ast.TypeAliased,
46 | 			ast.TypeParenthesis,
47 | 			ast.TypeFunctionLiteral,
48 | 		},
49 | 	}
50 | 	return filterPrefixGroup(astutil.NewNodeReader(parsed), prefixMatcher, peekMatcher)
51 | }
52 |
53 | // ExtractTableReferences returns at most one node: the first identifier
54 | // list, identifier, member identifier or aliased node that follows a FROM or
55 | // UPDATE keyword. It returns nil when nothing matches.
56 | func ExtractTableReferences(parsed ast.TokenList) []ast.Node {
57 | 	prefixMatcher := astutil.NodeMatcher{
58 | 		ExpectKeyword: []string{
59 | 			"FROM",
60 | 			"UPDATE",
61 | 		},
62 | 	}
63 | 	peekMatcher := astutil.NodeMatcher{
64 | 		NodeTypes: []ast.NodeType{
65 | 			ast.TypeIdentifierList,
66 | 			ast.TypeIdentifier,
67 | 			ast.TypeMemberIdentifier,
68 | 			ast.TypeAliased,
69 | 		},
70 | 	}
71 | 	return filterPrefixGroupOnce(astutil.NewNodeReader(parsed), prefixMatcher, peekMatcher)
72 | }
73 |
74 | // ExtractTableReference returns the identifier, member identifier or aliased
75 | // nodes that follow an INSERT INTO or DELETE FROM keyword.
76 | func ExtractTableReference(parsed ast.TokenList) []ast.Node {
77 | 	prefixMatcher := astutil.NodeMatcher{
78 | 		ExpectKeyword: []string{
79 | 			"INSERT INTO",
80 | 			"DELETE FROM",
81 | 		},
82 | 	}
83 | 	peekMatcher := astutil.NodeMatcher{
84 | 		NodeTypes: []ast.NodeType{
85 | 			ast.TypeIdentifier,
86 | 			ast.TypeMemberIdentifier,
87 | 			ast.TypeAliased,
88 | 		},
89 | 	}
90 | 	return filterPrefixGroup(astutil.NewNodeReader(parsed), prefixMatcher, peekMatcher)
91 | }
92 |
93 | // ExtractTableFactor returns the identifier, member identifier or aliased
94 | // nodes that follow one of the JOIN keywords listed below.
95 | func ExtractTableFactor(parsed ast.TokenList) []ast.Node {
96 | 	prefixMatcher := astutil.NodeMatcher{
97 | 		ExpectKeyword: []string{
98 | 			"JOIN",
99 | 			"INNER JOIN",
100 | 			"CROSS JOIN",
101 | 			"OUTER JOIN",
102 | 			"LEFT JOIN",
103 | 			"RIGHT JOIN",
104 | 			"LEFT OUTER JOIN",
105 | 			"RIGHT OUTER JOIN",
106 | 		},
107 | 	}
108 | 	peekMatcher := astutil.NodeMatcher{
109 | 		NodeTypes: []ast.NodeType{
110 | 			ast.TypeIdentifier,
111 | 			ast.TypeMemberIdentifier,
112 | 			ast.TypeAliased,
113 | 		},
114 | 	}
115 | 	return filterPrefixGroup(astutil.NewNodeReader(parsed), prefixMatcher, peekMatcher)
116 | }
117 |
118 | // ExtractWhereCondition returns the comparison or identifier-list nodes that
119 | // follow a WHERE keyword.
120 | func ExtractWhereCondition(parsed ast.TokenList) []ast.Node {
121 | 	prefixMatcher := astutil.NodeMatcher{
122 | 		ExpectKeyword: []string{
123 | 			"WHERE",
124 | 		},
125 | 	}
126 | 	peekMatcher := astutil.NodeMatcher{
127 | 		NodeTypes: []ast.NodeType{
128 | 			ast.TypeComparison,
129 | 			ast.TypeIdentifierList,
130 | 		},
131 | 	}
132 | 	return filterPrefixGroup(astutil.NewNodeReader(parsed), prefixMatcher, peekMatcher)
133 | }
134 |
135 | // ExtractAliased returns the aliased nodes that NodeReader.FindRecursive
136 | // finds under parsed.
137 | func ExtractAliased(parsed ast.TokenList) []ast.Node {
138 | 	reader := astutil.NewNodeReader(parsed)
139 | 	matcher := astutil.NodeMatcher{NodeTypes: []ast.NodeType{ast.TypeAliased}}
140 | 	return reader.FindRecursive(matcher)
141 | }
142 |
143 | // ExtractAliasedIdentifier returns the aliased nodes found under parsed,
144 | // skipping those whose real name is a token list that isSubQuery reports as
145 | // a subquery.
146 | func ExtractAliasedIdentifier(parsed ast.TokenList) []ast.Node {
147 | 	reader := astutil.NewNodeReader(parsed)
148 | 	matcher := astutil.NodeMatcher{NodeTypes: []ast.NodeType{ast.TypeAliased}}
149 | 	aliases := reader.FindRecursive(matcher)
150 |
151 | 	results := []ast.Node{}
152 | 	for _, node := range aliases {
153 | 		alias, ok := node.(*ast.Aliased)
154 | 		if !ok {
155 | 			continue
156 | 		}
157 | 		// A real name that is not a token list cannot be a subquery.
158 | 		if list, ok := alias.RealName.(ast.TokenList); ok && isSubQuery(list) {
159 | 			continue
160 | 		}
161 | 		results = append(results, node)
162 | 	}
163 | 	return results
164 | }
165 |
166 | // ExtractInsertColumns passes parsePrefix a matcher for identifier, member
167 | // identifier and aliased nodes together with parseInsertColumns, so the
168 | // result is the identifier list found in a parenthesis right after such a
169 | // node (in an INSERT, the column list after the table name).
170 | func ExtractInsertColumns(parsed ast.TokenList) []ast.Node {
171 | 	insertTableIdentifier := astutil.NodeMatcher{
172 | 		NodeTypes: []ast.NodeType{
173 | 			ast.TypeIdentifier,
174 | 			ast.TypeMemberIdentifier,
175 | 			ast.TypeAliased,
176 | 		},
177 | 	}
178 | 	return parsePrefix(astutil.NewNodeReader(parsed), insertTableIdentifier, parseInsertColumns)
179 | }
180 |
181 | // parseInsertColumns looks at the node after reader's current one and, if it
182 | // is a parenthesis, returns the identifier list it holds (the inner node
183 | // itself or its first token). It returns an empty slice otherwise.
184 | func parseInsertColumns(reader *astutil.NodeReader) []ast.Node {
185 | 	insertColumnsParenthesis := astutil.NodeMatcher{
186 | 		NodeTypes: []ast.NodeType{
187 | 			ast.TypeParenthesis,
188 | 		},
189 | 	}
190 |
191 | 	if !reader.PeekNodeIs(true, insertColumnsParenthesis) {
192 | 		return []ast.Node{}
193 | 	}
194 |
195 | 	_, parenthesisNode := reader.PeekNode(true)
196 | 	parenthesis, ok := parenthesisNode.(*ast.Parenthesis)
197 | 	if !ok {
198 | 		return []ast.Node{}
199 | 	}
200 |
201 | 	if inner, ok := parenthesis.Inner().(*ast.IdentifierList); ok {
202 | 		return []ast.Node{inner}
203 | 	}
204 | 	list := parenthesis.Inner().GetTokens()
205 | 	if len(list) > 0 {
206 | 		if firstToken, ok := list[0].(*ast.IdentifierList); ok {
207 | 			return []ast.Node{firstToken}
208 | 		}
209 | 	}
210 | 	return []ast.Node{}
211 | }
212 |
213 | // ExtractInsertValues returns the VALUES row whose span encloses pos: the
214 | // identifier list inside a parenthesis that follows VALUES or a comma. It
215 | // returns an empty slice when no such row encloses pos.
216 | func ExtractInsertValues(parsed ast.TokenList, pos token.Pos) []ast.Node {
217 | 	// Rows start right after VALUES or after the comma separating them from
218 | 	// the previous row.
219 | 	valuesRowPrefix := astutil.NodeMatcher{
220 | 		ExpectTokens: []token.Kind{
221 | 			token.Comma,
222 | 		},
223 | 		ExpectKeyword: []string{
224 | 			"VALUES",
225 | 		},
226 | 	}
227 | 	values := parsePrefix(astutil.NewNodeReader(parsed), valuesRowPrefix, parseInsertValues)
228 | 	for _, v := range values {
229 | 		if astutil.IsEnclose(v, pos) {
230 | 			return []ast.Node{v}
231 | 		}
232 | 	}
233 | 	return []ast.Node{}
234 | }
235 |
236 | // parseInsertValues looks at the node after reader's current one and, if it
237 | // is a parenthesis whose first inner token is an identifier list, returns
238 | // that list. It returns an empty slice otherwise, including when the
239 | // parenthesis has no inner tokens.
240 | func parseInsertValues(reader *astutil.NodeReader) []ast.Node {
241 | 	insertValuesParenthesis := astutil.NodeMatcher{
242 | 		NodeTypes: []ast.NodeType{
243 | 			ast.TypeParenthesis,
244 | 		},
245 | 	}
246 | 	if !reader.PeekNodeIs(true, insertValuesParenthesis) {
247 | 		return []ast.Node{}
248 | 	}
249 |
250 | 	_, parenthesisNode := reader.PeekNode(true)
251 | 	parenthesis, ok := parenthesisNode.(*ast.Parenthesis)
252 | 	if !ok {
253 | 		return []ast.Node{}
254 | 	}
255 | 	// Check the length first: indexing [0] on an empty token list panics.
256 | 	tokens := parenthesis.Inner().GetTokens()
257 | 	if len(tokens) == 0 {
258 | 		return []ast.Node{}
259 | 	}
260 | 	identList, ok := tokens[0].(*ast.IdentifierList)
261 | 	if !ok {
262 | 		return []ast.Node{}
263 | 	}
264 | 	return []ast.Node{identList}
265 | }
266 |
267 | // filterPrefixGroup walks reader's nodes, descending into every token list,
268 | // and collects each node that follows a node matching prefixMatcher when
269 | // that following node matches peekMatcher.
270 | func filterPrefixGroup(reader *astutil.NodeReader, prefixMatcher astutil.NodeMatcher, peekMatcher astutil.NodeMatcher) []ast.Node {
271 | 	var results []ast.Node
272 | 	for reader.NextNode(false) {
273 | 		if reader.CurNodeIs(prefixMatcher) && reader.PeekNodeIs(true, peekMatcher) {
274 | 			_, node := reader.PeekNode(true)
275 | 			results = append(results, node)
276 | 		}
277 | 		if list, ok := reader.CurNode.(ast.TokenList); ok {
278 | 			newReader := astutil.NewNodeReader(list)
279 | 			results = append(results, filterPrefixGroup(newReader, prefixMatcher, peekMatcher)...)
280 | 		}
281 | 	}
282 | 	return results
283 | }
284 |
285 | // filterPrefixGroupOnce returns only the first node filterPrefixGroup
286 | // collects, or nil when it collects none.
287 | func filterPrefixGroupOnce(reader *astutil.NodeReader, prefixMatcher astutil.NodeMatcher, peekMatcher astutil.NodeMatcher) []ast.Node {
288 | 	results := filterPrefixGroup(reader, prefixMatcher, peekMatcher)
289 | 	if len(results) > 0 {
290 | 		return []ast.Node{results[0]}
291 | 	}
292 | 	return nil
293 | }
294 |
--------------------------------------------------------------------------------
/parser/parseutil/idenfier.go:
--------------------------------------------------------------------------------
1 | package parseutil
2 |
3 | import (
4 | 	"github.com/sqls-server/sqls/ast"
5 | 	"github.com/sqls-server/sqls/ast/astutil"
6 | 	"github.com/sqls-server/sqls/token"
7 | )
8 |
9 | // ExtractIdenfiers takes the statement that extractFocusedStatement selects
10 | // for pos and returns the identifier nodes that parsePrefix reports in it,
11 | // each produced by parseIdentifier. Errors from extractFocusedStatement are
12 | // returned unchanged.
13 | func ExtractIdenfiers(parsed ast.TokenList, pos token.Pos) ([]ast.Node, error) {
14 | 	stmt, err := extractFocusedStatement(parsed, pos)
15 | 	if err != nil {
16 | 		return nil, err
17 | 	}
18 |
19 | 	identifierMatcher := astutil.NodeMatcher{
20 | 		NodeTypes: []ast.NodeType{
21 | 			ast.TypeIdentifier,
22 | 		},
23 | 	}
24 | 	return parsePrefix(astutil.NewNodeReader(stmt), identifierMatcher, parseIdentifier), nil
25 | }
26 |
27 | // parseIdentifier is the parsePrefix callback used by ExtractIdenfiers; it
28 | // returns reader's current node unchanged.
29 | func parseIdentifier(reader *astutil.NodeReader) []ast.Node {
30 | 	return []ast.Node{reader.CurNode}
31 | }
32 |
--------------------------------------------------------------------------------
/parser/parseutil/insert.go:
--------------------------------------------------------------------------------
1 | package parseutil
2 |
3 | import (
4 | 	"github.com/sqls-server/sqls/ast"
5 | 	"github.com/sqls-server/sqls/token"
6 | )
7 |
8 | // Insert holds the parts of an INSERT statement collected by ExtractInsert.
9 | type Insert struct {
10 | 	Tables  []*TableInfo
11 | 	Columns []*ast.IdentifierList
12 | 	Values  []*ast.IdentifierList
13 | }
14 |
15 | // Enable reports whether at least one table, one column list and one values
16 | // list were collected.
17 | func (i *Insert) Enable() bool {
18 | 	return len(i.Tables) > 0 && len(i.Columns) > 0 && len(i.Values) > 0
19 | }
20 |
21 | // GetTable returns the first collected table, or nil if there is none.
22 | func (i *Insert) GetTable() *TableInfo {
23 | 	if len(i.Tables) == 0 {
24 | 		return nil
25 | 	}
26 | 	return i.Tables[0]
27 | }
28 |
29 | // GetColumns returns the first collected column list, or nil if there is none.
30 | func (i *Insert) GetColumns() *ast.IdentifierList {
31 | 	if len(i.Columns) == 0 {
32 | 		return nil
33 | 	}
34 | 	return i.Columns[0]
35 | }
36 |
37 | // GetValues returns the first collected values list, or nil if there is none.
38 | func (i *Insert) GetValues() *ast.IdentifierList {
39 | 	if len(i.Values) == 0 {
40 | 		return nil
41 | 	}
42 | 	return i.Values[0]
43 | }
44 |
45 | // ExtractInsert collects the tables, column list and values row of the
46 | // INSERT statement at pos. Tables come from ExtractTable over the whole
47 | // parsed input; columns and values come from the statement that
48 | // extractFocusedStatement selects, and only the values row enclosing pos is
49 | // kept. Errors from either lookup are returned unchanged.
50 | func ExtractInsert(parsed ast.TokenList, pos token.Pos) (*Insert, error) {
51 | 	stmt, err := extractFocusedStatement(parsed, pos)
52 | 	if err != nil {
53 | 		return nil, err
54 | 	}
55 |
56 | 	tables, err := ExtractTable(parsed, pos)
57 | 	if err != nil {
58 | 		return nil, err
59 | 	}
60 |
61 | 	columnsNodes := ExtractInsertColumns(stmt)
62 | 	columns := make([]*ast.IdentifierList, 0, len(columnsNodes))
63 | 	for _, n := range columnsNodes {
64 | 		if c, ok := n.(*ast.IdentifierList); ok {
65 | 			columns = append(columns, c)
66 | 		}
67 | 	}
68 |
69 | 	valuesNodes := ExtractInsertValues(stmt, pos)
70 | 	values := make([]*ast.IdentifierList, 0, len(valuesNodes))
71 | 	for _, n := range valuesNodes {
72 | 		if v, ok := n.(*ast.IdentifierList); ok {
73 | 			values = append(values, v)
74 | 		}
75 | 	}
76 |
77 | 	return &Insert{
78 | 		Tables:  tables,
79 | 		Columns: columns,
80 | 		Values:  values,
81 | 	}, nil
82 | }
83 |
--------------------------------------------------------------------------------
/parser/parseutil/insert_test.go:
--------------------------------------------------------------------------------
1 | package parseutil
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/google/go-cmp/cmp"
7 | 	"github.com/sqls-server/sqls/token"
8 | )
9 |
10 | // TestExtractInsert checks the table, column list and values row that
11 | // ExtractInsert returns for single- and multi-row INSERT statements.
12 | func TestExtractInsert(t *testing.T) {
13 | 	testcases := []struct {
14 | 		name  string
15 | 		input string
16 | 		pos   token.Pos
17 | 		tbl   *TableInfo
18 | 		cols  string
19 | 		vals  string
20 | 	}{
21 | 		{
22 | 			name:  "single",
23 | 			input: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020')",
24 | 			pos: token.Pos{
25 | 				Line: 0,
26 | 				Col:  50,
27 | 			},
28 | 			tbl: &TableInfo{
29 | 				Name: "city",
30 | 			},
31 | 			cols: "ID, Name, CountryCode",
32 | 			vals: "123, 'aaa', '2020'",
33 | 		},
34 | 		{
35 | 			name:  "multi forcus first",
36 | 			input: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020'), (456, 'bbb', '2021')",
37 | 			pos: token.Pos{
38 | 				Line: 0,
39 | 				Col:  50,
40 | 			},
41 | 			tbl: &TableInfo{
42 | 				Name: "city",
43 | 			},
44 | 			cols: "ID, Name, CountryCode",
45 | 			vals: "123, 'aaa', '2020'",
46 | 		},
47 | 		{
48 | 			name:  "multi forcus second",
49 | 			input: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020'), (456, 'bbb', '2021')",
50 | 			pos: token.Pos{
51 | 				Line: 0,
52 | 				Col:  72,
53 | 			},
54 | 			tbl: &TableInfo{
55 | 				Name: "city",
56 | 			},
57 | 			cols: "ID, Name, CountryCode",
58 | 			vals: "456, 'bbb', '2021'",
59 | 		},
60 | 	}
61 |
62 | 	for _, tt := range testcases {
63 | 		t.Run(tt.name, func(t *testing.T) {
64 | 			stmt := initExtractTable(t, tt.input)
65 | 			got, err := ExtractInsert(stmt, tt.pos)
66 | 			if err != nil {
67 | 				t.Fatalf("error: %+v", err)
68 | 			}
69 | 			if d := cmp.Diff(tt.tbl, got.GetTable()); d != "" {
70 | 				t.Errorf("unmatched table info(-want, +got): %s", d)
71 | 			}
72 | 			// Stop with a clear message rather than calling String on a nil list.
73 | 			cols := got.GetColumns()
74 | 			if cols == nil {
75 | 				t.Fatal("no column list extracted")
76 | 			}
77 | 			if d := cmp.Diff(tt.cols, cols.String()); d != "" {
78 | 				t.Errorf("unmatched columns info(-want, +got): %s", d)
79 | 			}
80 | 			vals := got.GetValues()
81 | 			if vals == nil {
82 | 				t.Fatal("no values list extracted")
83 | 			}
84 | 			if d := cmp.Diff(tt.vals, vals.String()); d != "" {
85 | 				t.Errorf("unmatched values info(-want, +got): %s", d)
86 | 			}
87 | 		})
88 | 	}
89 | }
90 |
--------------------------------------------------------------------------------
/parser/parseutil/position.go:
--------------------------------------------------------------------------------
1 | package parseutil
2 |
3 | import (
4 | 	"github.com/sqls-server/sqls/ast"
5 | 	"github.com/sqls-server/sqls/ast/astutil"
6 | 	"github.com/sqls-server/sqls/token"
7 | )
8 |
9 | // SyntaxPosition names the syntactic context of a cursor position, as
10 | // reported by CheckSyntaxPosition.
11 | type SyntaxPosition string
12 |
13 | const (
14 | 	ColName        SyntaxPosition = "col_name"
15 | 	SelectExpr     SyntaxPosition = "select_expr"
16 | 	AliasName      SyntaxPosition = "alias_name"
17 | 	WhereCondition SyntaxPosition = "where_condition"
18 | 	CaseValue      SyntaxPosition = "case_value"
19 | 	TableReference SyntaxPosition = "table_reference"
20 | 	InsertColumn   SyntaxPosition = "insert_column"
21 | 	InsertValue    SyntaxPosition = "insert_value"
22 | 	JoinClause     SyntaxPosition = "join_clause"
23 | 	JoinOn         SyntaxPosition = "join_on"
24 | 	Unknown        SyntaxPosition = "unknown"
25 | )
26 |
27 | // CheckSyntaxPosition classifies the cursor position tracked by nw. Cases
28 | // are tried in order and the first match wins. Most of them check whether
29 | // the node before the current one, ignoring whitespace, is one of the listed
30 | // keywords at any depth of nw. A position inside a parenthesis is
31 | // InsertValue when that parenthesis follows VALUES or a comma and
32 | // InsertColumn otherwise; anything else is Unknown.
33 | func CheckSyntaxPosition(nw *NodeWalker) SyntaxPosition {
34 | 	switch {
35 | 	case nw.PrevNodesIs(true, genKeywordMatcher([]string{
36 | 		// UPDATE Statement
37 | 		"SET",
38 | 		// SELECT Statement
39 | 		"ORDER BY",
40 | 		"GROUP BY",
41 | 	})):
42 | 		return ColName
43 | 	case nw.PrevNodesIs(true, genKeywordMatcher([]string{
44 | 		// SELECT Statement
45 | 		"ALL",
46 | 		"DISTINCT",
47 | 		"DISTINCTROW",
48 | 		"SELECT",
49 | 	})):
50 | 		return SelectExpr
51 | 	case nw.PrevNodesIs(true, genKeywordMatcher([]string{
52 | 		// Alias
53 | 		"AS",
54 | 	})):
55 | 		return AliasName
56 | 	case nw.PrevNodesIs(true, genKeywordMatcher([]string{
57 | 		// WHERE Clause
58 | 		"WHERE",
59 | 		"HAVING",
60 | 		// Operator
61 | 		"AND",
62 | 		"OR",
63 | 		"XOR",
64 | 	})):
65 | 		return WhereCondition
66 | 	case nw.PrevNodesIs(true, genKeywordMatcher([]string{
67 | 		// CASE Statement
68 | 		"CASE",
69 | 		"WHEN",
70 | 		"THEN",
71 | 		"ELSE",
72 | 	})):
73 | 		return CaseValue
74 | 	case nw.PrevNodesIs(true, genKeywordMatcher([]string{
75 | 		// SELECT Statement
76 | 		"FROM",
77 | 		// UPDATE Statement
78 | 		"UPDATE",
79 | 		// DELETE Statement
80 | 		"DELETE FROM",
81 | 		// INSERT Statement
82 | 		"INSERT INTO",
83 | 		// JOIN Clause
84 | 		"CROSS JOIN",
85 | 		// DESCRIBE Statement
86 | 		"DESCRIBE",
87 | 		"DESC",
88 | 		// TRUNCATE Statement
89 | 		"TRUNCATE",
90 | 	})):
91 | 		return TableReference
92 | 	case nw.PrevNodesIs(true, genKeywordMatcher([]string{
93 | 		"ON",
94 | 	})):
95 | 		return getJoinOnCondition(nw)
96 | 	case nw.PrevNodesIs(true, genKeywordMatcher([]string{
97 | 		"JOIN",
98 | 		"INNER JOIN",
99 | 		"OUTER JOIN",
100 | 		"LEFT JOIN",
101 | 		"RIGHT JOIN",
102 | 		"LEFT OUTER JOIN",
103 | 		"RIGHT OUTER JOIN",
104 | 	})):
105 | 		return getJoinCondition(nw)
106 | 	case isInsertColumns(nw):
107 | 		if isInsertValues(nw) {
108 | 			return InsertValue
109 | 		}
110 | 		return InsertColumn
111 | 	default:
112 | 		return Unknown
113 | 	}
114 | }
115 |
116 | // getJoinCondition resolves a position that follows a JOIN keyword: it is a
117 | // TableReference when, at some depth, the node after the current one is an
118 | // ON keyword, and a JoinClause otherwise.
119 | func getJoinCondition(nw *NodeWalker) SyntaxPosition {
120 | 	onMatcher := genKeywordMatcher([]string{"ON"})
121 | 	for _, reader := range nw.Paths {
122 | 		if reader.PeekNodeIs(true, onMatcher) {
123 | 			return TableReference
124 | 		}
125 | 	}
126 | 	return JoinClause
127 | }
128 |
129 | // getJoinOnCondition resolves a position that follows an ON keyword: ColName
130 | // when the current node is a period, JoinOn when it is whitespace not
131 | // preceded by "=", and WhereCondition otherwise.
132 | func getJoinOnCondition(nw *NodeWalker) SyntaxPosition {
133 | 	switch {
134 | 	case nw.CurNodeIs(genTokenMatcher([]token.Kind{token.Period})):
135 | 		return ColName
136 | 	case nw.CurNodeIs(genTokenMatcher([]token.Kind{token.Whitespace})):
137 | 		if !nw.PrevNodesIs(true, genTokenMatcher([]token.Kind{token.Eq})) {
138 | 			return JoinOn
139 | 		}
140 | 	}
141 | 	return WhereCondition
142 | }
143 |
144 | // genKeywordMatcher returns a NodeMatcher whose ExpectKeyword is keywords.
145 | func genKeywordMatcher(keywords []string) astutil.NodeMatcher {
146 | 	return astutil.NodeMatcher{
147 | 		ExpectKeyword: keywords,
148 | 	}
149 | }
150 |
151 | // genTokenMatcher returns a NodeMatcher whose ExpectTokens is tokens.
152 | func genTokenMatcher(tokens []token.Kind) astutil.NodeMatcher {
153 | 	return astutil.NodeMatcher{
154 | 		ExpectTokens: tokens,
155 | 	}
156 | }
157 |
158 | // isInsertColumns reports whether the current node at any depth of nw is a
159 | // parenthesis. It does not itself check that the statement is an INSERT.
160 | func isInsertColumns(nw *NodeWalker) bool {
161 | 	parenthesisMatcher := astutil.NodeMatcher{
162 | 		NodeTypes: []ast.NodeType{
163 | 			ast.TypeParenthesis,
164 | 		},
165 | 	}
166 | 	return nw.CurNodeIs(parenthesisMatcher)
167 | }
168 |
169 | // isInsertValues reports whether the outermost parenthesis among nw's
170 | // current nodes is preceded, at its own depth and ignoring whitespace, by a
171 | // VALUES keyword or a comma, as for the rows of "VALUES (...), (...)".
172 | func isInsertValues(nw *NodeWalker) bool {
173 | 	parenthesisMatcher := astutil.NodeMatcher{
174 | 		NodeTypes: []ast.NodeType{
175 | 			ast.TypeParenthesis,
176 | 		},
177 | 	}
178 | 	depth, ok := nw.CurNodeDepth(parenthesisMatcher)
179 | 	if !ok {
180 | 		return false
181 | 	}
182 | 	return nw.PrevNodesIsWithDepth(true, genKeywordMatcher([]string{"VALUES"}), depth) ||
183 | 		nw.PrevNodesIsWithDepth(true, genTokenMatcher([]token.Kind{token.Comma}), depth)
184 | }
185 |
--------------------------------------------------------------------------------
/parser/parseutil/position_test.go:
--------------------------------------------------------------------------------
1 | package parseutil
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/sqls-server/sqls/parser"
7 | 	"github.com/sqls-server/sqls/token"
8 | )
9 |
10 | // TestCheckSyntaxPosition parses each text and checks the SyntaxPosition
11 | // that CheckSyntaxPosition reports for pos.
12 | func TestCheckSyntaxPosition(t *testing.T) {
13 | 	tests := []struct {
14 | 		name string
15 | 		text string
16 | 		pos  token.Pos
17 | 		want SyntaxPosition
18 | 	}{
19 | 		{
20 | 			name: "insert column",
21 | 			text: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020')",
22 | 			pos: token.Pos{
23 | 				Line: 0,
24 | 				Col:  57,
25 | 			},
26 | 			want: InsertValue,
27 | 		},
28 | 		{
29 | 			name: "on lparen",
30 | 			text: "insert into city (ID, Name, CountryCode) VALUES (",
31 | 			pos: token.Pos{
32 | 				Line: 0,
33 | 				Col:  49,
34 | 			},
35 | 			want: InsertValue,
36 | 		},
37 | 		{
38 | 			name: "with space first param",
39 | 			text: "insert into city (ID, Name, CountryCode) VALUES ( ",
40 | 			pos: token.Pos{
41 | 				Line: 0,
42 | 				Col:  52,
43 | 			},
44 | 			want: InsertValue,
45 | 		},
46 | 		{
47 | 			name: "second param",
48 | 			text: "insert into city (ID, Name, CountryCode) VALUES (123, ",
49 | 			pos: token.Pos{
50 | 				Line: 0,
51 | 				Col:  54,
52 | 			},
53 | 			want: InsertValue,
54 | 		},
55 | 		{
56 | 			name: "white space with second param",
57 | 			text: "insert into city (ID, Name, CountryCode) VALUES (123, ",
58 | 			pos: token.Pos{
59 | 				Line: 0,
60 | 				Col:  56,
61 | 			},
62 | 			want: InsertValue,
63 | 		},
64 | 		{
65 | 			name: "third param",
66 | 			text: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020' ",
67 | 			pos: token.Pos{
68 | 				Line: 0,
69 | 				Col:  68,
70 | 			},
71 | 			want: InsertValue,
72 | 		},
73 | 		{
74 | 			name: "white space with third param",
75 | 			text: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020' ",
76 | 			pos: token.Pos{
77 | 				Line: 0,
78 | 				Col:  70,
79 | 			},
80 | 			want: InsertValue,
81 | 		},
82 | 		{
83 | 			name: "join tables",
84 | 			text: "select CountryCode from city join ",
85 | 			pos: token.Pos{
86 | 				Line: 0,
87 | 				Col:  34,
88 | 			},
89 | 			want: JoinClause,
90 | 		},
91 | 		{
92 | 			name: "join filtered tables",
93 | 			text: "select CountryCode from city join co",
94 | 			pos: token.Pos{
95 | 				Line: 0,
96 | 				Col:  36,
97 | 			},
98 | 			want: JoinClause,
99 | 		},
100 | 		{
101 | 			name: "left join tables",
102 | 			text: "select CountryCode from city left join ",
103 | 			pos: token.Pos{
104 | 				Line: 0,
105 | 				Col:  39,
106 | 			},
107 | 			want: JoinClause,
108 | 		},
109 | 		{
110 | 			name: "left outer join tables",
111 | 			text: "select CountryCode from city left outer join ",
112 | 			pos: token.Pos{
113 | 				Line: 0,
114 | 				Col:  45,
115 | 			},
116 | 			want: JoinClause,
117 | 		},
118 | 		{
119 | 			name: "join on columns",
120 | 			text: "select * from city left join country on ",
121 | 			pos: token.Pos{
122 | 				Line: 0,
123 | 				Col:  40,
124 | 			},
125 | 			want: JoinOn,
126 | 		},
127 | 		{
128 | 			name: "join on filtered columns",
129 | 			text: "select * from city left join country on co",
130 | 			pos: token.Pos{
131 | 				Line: 0,
132 | 				Col:  42,
133 | 			},
134 | 			want: WhereCondition,
135 | 		},
136 | 		{
137 | 			name: "join on table",
138 | 			text: "select * from city left join country on country.",
139 | 			pos: token.Pos{
140 | 				Line: 0,
141 | 				Col:  48,
142 | 			},
143 | 			want: ColName,
144 | 		},
145 | 		{
146 | 			name: "join on ",
147 | 			text: "select * from city left join country on country.Code =",
148 | 			pos: token.Pos{
149 | 				Line: 0,
150 | 				Col:  54,
151 | 			},
152 | 			want: WhereCondition,
153 | 		},
154 | 		{
155 | 			name: "join on ",
156 | 			text: "select * from city left join country on country.Code = ",
157 | 			pos: token.Pos{
158 | 				Line: 0,
159 | 				Col:  55,
160 | 			},
161 | 			want: WhereCondition,
162 | 		},
163 | 		{
164 | 			name: "join on ref tables filtered",
165 | 			text: "select * from city left join country on country.Code = ci",
166 | 			pos: token.Pos{
167 | 				Line: 0,
168 | 				Col:  57,
169 | 			},
170 | 			want: WhereCondition,
171 | 		},
172 | 		{
173 | 			name: "join on ref table",
174 | 			text: "select * from city left join country on country.Code = city.",
175 | 			pos: token.Pos{
176 | 				Line: 0,
177 | 				Col:  60,
178 | 			},
179 | 			want: ColName,
180 | 		},
181 | 		{
182 | 			name: "join alias snippet",
183 | 			text: "select * from city c left join country c1 on c1.Code",
184 | 			pos: token.Pos{
185 | 				Line: 0,
186 | 				Col:  39,
187 | 			},
188 | 			want: TableReference,
189 | 		},
190 | 	}
191 | 	for _, tt := range tests {
192 | 		t.Run(tt.name, func(t *testing.T) {
193 | 			parsed, err := parser.Parse(tt.text)
194 | 			if err != nil {
195 | 				t.Fatalf("parse error, %v", err)
196 | 			}
197 |
198 | 			nodeWalker := NewNodeWalker(parsed, tt.pos)
199 | 			if got := CheckSyntaxPosition(nodeWalker); got != tt.want {
200 | 				t.Errorf("unmatch syntax position got %v, want %v", got, tt.want)
201 | 			}
202 | 		})
203 | 	}
204 | }
205 |
--------------------------------------------------------------------------------
/parser/parseutil/walk.go:
--------------------------------------------------------------------------------
1 | package parseutil
2 |
3 | import (
4 | 	"github.com/sqls-server/sqls/ast"
5 | 	"github.com/sqls-server/sqls/ast/astutil"
6 | 	"github.com/sqls-server/sqls/token"
7 | )
8 |
9 | // NodeWalker holds, for one cursor position, a NodeReader per nesting level:
10 | // Paths[0] is positioned on the top-level node that encloses the position
11 | // and each following reader on the enclosing node one level deeper.
12 | type NodeWalker struct {
13 | 	Paths []*astutil.NodeReader
14 | 	// Index is not set by NewNodeWalker.
15 | 	Index int
16 | }
17 |
18 | // astPaths advances reader to the first node that encloses pos and, when
19 | // that node is a token list, repeats the search inside it. It returns the
20 | // readers involved, outermost first, each left positioned on its enclosing
21 | // node; the result is empty when nothing at reader's level encloses pos.
22 | func astPaths(reader *astutil.NodeReader, pos token.Pos) []*astutil.NodeReader {
23 | 	paths := []*astutil.NodeReader{}
24 | 	for reader.NextNode(false) {
25 | 		if reader.CurNodeEncloseIs(pos) {
26 | 			paths = append(paths, reader)
27 | 			if list, ok := reader.CurNode.(ast.TokenList); ok {
28 | 				newReader := astutil.NewNodeReader(list)
29 | 				return append(paths, astPaths(newReader, pos)...)
30 | 			}
31 | 			return paths
32 | 		}
33 | 	}
34 | 	return paths
35 | }
36 |
37 | // NewNodeWalker builds a NodeWalker whose Paths lead from root down to the
38 | // innermost node enclosing pos.
39 | func NewNodeWalker(root ast.TokenList, pos token.Pos) *NodeWalker {
40 | 	return &NodeWalker{
41 | 		Paths: astPaths(astutil.NewNodeReader(root), pos),
42 | 	}
43 | }
44 |
45 | // CurNodeIs reports whether the current node at any depth matches matcher.
46 | func (nw *NodeWalker) CurNodeIs(matcher astutil.NodeMatcher) bool {
47 | 	for _, reader := range nw.Paths {
48 | 		if reader.CurNodeIs(matcher) {
49 | 			return true
50 | 		}
51 | 	}
52 | 	return false
53 | }
54 |
55 | // CurNodeDepth returns the depth (index into Paths) of the outermost current
56 | // node that matches matcher, and false when none does.
57 | func (nw *NodeWalker) CurNodeDepth(matcher astutil.NodeMatcher) (int, bool) {
58 | 	for i, reader := range nw.Paths {
59 | 		if reader.CurNodeIs(matcher) {
60 | 			return i, true
61 | 		}
62 | 	}
63 | 	return 0, false
64 | }
65 |
66 | // CurNodeMatches returns the current nodes that match matcher, outermost
67 | // first.
68 | func (nw *NodeWalker) CurNodeMatches(matcher astutil.NodeMatcher) []ast.Node {
69 | 	matches := []ast.Node{}
70 | 	for _, reader := range nw.Paths {
71 | 		if reader.CurNodeIs(matcher) {
72 | 			matches = append(matches, reader.CurNode)
73 | 		}
74 | 	}
75 | 	return matches
76 | }
77 |
78 | // CurNodeTopMatched returns the outermost current node that matches
79 | // matcher, or nil.
80 | func (nw *NodeWalker) CurNodeTopMatched(matcher astutil.NodeMatcher) ast.Node {
81 | 	matches := nw.CurNodeMatches(matcher)
82 | 	if len(matches) == 0 {
83 | 		return nil
84 | 	}
85 | 	return matches[0]
86 | }
87 |
88 | // CurNodeBottomMatched returns the innermost current node that matches
89 | // matcher, or nil.
90 | func (nw *NodeWalker) CurNodeBottomMatched(matcher astutil.NodeMatcher) ast.Node {
91 | 	matches := nw.CurNodeMatches(matcher)
92 | 	if len(matches) == 0 {
93 | 		return nil
94 | 	}
95 | 	return matches[len(matches)-1]
96 | }
97 |
98 | // CurNodes returns the current node at every depth, outermost first.
99 | func (nw *NodeWalker) CurNodes() []ast.Node {
100 | 	results := make([]ast.Node, 0, len(nw.Paths))
101 | 	for _, reader := range nw.Paths {
102 | 		results = append(results, reader.CurNode)
103 | 	}
104 | 	return results
105 | }
106 |
107 | // PrevNodes returns, for every depth, the node that reader.PrevNode reports
108 | // before the current one.
109 | func (nw *NodeWalker) PrevNodes(ignoreWhitespace bool) []ast.Node {
110 | 	results := make([]ast.Node, 0, len(nw.Paths))
111 | 	for _, reader := range nw.Paths {
112 | 		_, node := reader.PrevNode(ignoreWhitespace)
113 | 		results = append(results, node)
114 | 	}
115 | 	return results
116 | }
117 |
118 | // PrevNodesIs reports whether, at any depth, the node before the current
119 | // one matches matcher.
120 | func (nw *NodeWalker) PrevNodesIs(ignoreWhitespace bool, matcher astutil.NodeMatcher) bool {
121 | 	for _, reader := range nw.Paths {
122 | 		if reader.PrevNodeIs(ignoreWhitespace, matcher) {
123 | 			return true
124 | 		}
125 | 	}
126 | 	return false
127 | }
128 |
129 | // PrevNodesIsWithDepth reports whether the node before the current one at
130 | // the given depth (an index into Paths) matches matcher. A depth outside
131 | // Paths reports false instead of panicking.
132 | func (nw *NodeWalker) PrevNodesIsWithDepth(ignoreWhitespace bool, matcher astutil.NodeMatcher, depth int) bool {
133 | 	if depth < 0 || depth >= len(nw.Paths) {
134 | 		return false
135 | 	}
136 | 	return nw.Paths[depth].PrevNodeIs(ignoreWhitespace, matcher)
137 | }
138 |
--------------------------------------------------------------------------------
/schema.json:
--------------------------------------------------------------------------------
1 | {
2 |   "$schema": "http://json-schema.org/draft-04/schema",
3 |   "additionalProperties": false,
4 |   "definitions": {
5 |     "connection-definition": {
6 |       "description": "Database connections",
7 |       "type": "array",
8 |       "items": {
9 |         "additionalProperties": false,
10 |         "type": "object",
11 |         "properties": {
12 |           "alias": {
13 |             "description": "Connection alias name. Optional",
14 |             "type": "string"
15 |           },
16 |           "driver": {
17 |             "description": "mysql, postgresql, sqlite3, mssql, h2. Required",
18 |             "type": "string",
19 |             "enum": [
20 |               "mysql",
21 |               "postgresql",
22 |               "sqlite3",
23 |               "mssql",
24 |               "h2"
25 |             ]
26 |           },
27 |           "dataSourceName": {
28 |             "description": "Data source name",
29 |             "type": "string"
30 |           },
31 |           "proto": {
32 |             "description": "tcp, udp, unix",
33 |             "type": "string",
34 |             "enum": [
35 |               "tcp",
36 |               "udp",
37 |               "unix"
38 |             ]
39 |           },
40 |           "user": {
41 |             "description": "User name",
42 |             "type": "string"
43 |           },
44 |           "passwd": {
45 |             "description": "Password",
46 |             "type": "string"
47 |           },
48 |           "host": {
49 |             "description": "Host",
50 |             "type": "string"
51 |           },
52 |           "port": {
53 |             "description": "Port",
54 |             "type": "number"
55 |           },
56 |           "path": {
57 |             "description": "unix socket path",
58 |             "type": "string"
59 |           },
60 |           "dbName": {
61 |             "description": "Database name",
62 |             "type": "string"
63 |           },
64 |           "params": {
65 |             "description": "Option params. Optional",
66 |             "type": "object",
67 |             "properties": {}
68 |           },
69 |           "sshConfig": {
70 |             "description": "ssh config. Optional",
71 |             "type": "object",
72 |             "properties": {
73 |               "host": {
74 |                 "description": "ssh host. Required",
75 |                 "type": "string"
76 |               },
77 |               "port": {
78 |                 "description": "ssh port. Required",
79 |                 "type": "number"
80 |               },
81 |               "user": {
82 |                 "description": "ssh user. Optional",
83 |                 "type": "string"
84 |               },
85 |               "privateKey": {
86 |                 "description": "private key path. Required",
87 |                 "type": "string"
88 |               },
89 |               "passPhrase": {
90 |                 "description": "passPhrase. Optional",
91 |                 "type": "string"
92 |               }
93 |             }
94 |           }
95 |         }
96 |       }
97 |     }
98 |   },
99 |   "properties": {
100 |     "lowercaseKeywords": {
101 |       "description": "Set to true to use lowercase keywords instead of uppercase.",
102 |       "type": "boolean"
103 |     },
104 |     "connections": {
105 |       "$ref": "#/definitions/connection-definition"
106 |     }
107 |   },
108 |   "title": "sqls",
109 |   "type": "object"
110 | }
111 |
--------------------------------------------------------------------------------
/script/help_categories.sql:
--------------------------------------------------------------------------------
1 | SELECT hc1.help_category_id,
2 |        hc1.name,
3 |        hc2.help_category_id,
4 |        hc2.name
5 | FROM help_category AS hc1
6 | JOIN help_category AS hc2 ON hc1.help_category_id = hc2.parent_category_id
7 | ORDER BY hc1.help_category_id
8 |
--------------------------------------------------------------------------------
/script/help_functions_mysql56.sql:
--------------------------------------------------------------------------------
1 | SELECT distinct hk.name as keyword_name
2 | FROM help_keyword AS hk
3 | LEFT JOIN help_relation AS hr ON hr.help_keyword_id = hk.help_keyword_id
4 | LEFT JOIN help_topic AS ht ON ht.help_topic_id = hr.help_topic_id
5 | LEFT JOIN help_category AS hc ON hc.help_category_id = ht.help_category_id
6 | where hc.parent_category_id IN (4, 8, 21)
7 | order by keyword_name
8 |
--------------------------------------------------------------------------------
/script/help_functions_mysql57.sql:
--------------------------------------------------------------------------------
1 | SELECT distinct hk.name as keyword_name
2 | FROM help_keyword AS hk
3 | LEFT JOIN help_relation AS hr ON hr.help_keyword_id = hk.help_keyword_id
4 | LEFT JOIN help_topic AS ht ON ht.help_topic_id = hr.help_topic_id
5 | LEFT JOIN help_category AS hc ON hc.help_category_id = ht.help_category_id
6 | where hc.parent_category_id IN (6, 9, 20)
7 | order by keyword_name
8 |
--------------------------------------------------------------------------------
/script/help_functions_mysql8.sql:
--------------------------------------------------------------------------------
1 | SELECT distinct hk.name as keyword_name
2 | FROM help_keyword AS hk
3 | LEFT JOIN help_relation AS hr ON hr.help_keyword_id = hk.help_keyword_id
4 | LEFT JOIN help_topic AS ht ON ht.help_topic_id = hr.help_topic_id
5 | LEFT JOIN help_category AS hc ON hc.help_category_id = ht.help_category_id
6 | where hc.parent_category_id IN (4, 7, 22)
7 | order by keyword_name
8 |
--------------------------------------------------------------------------------
/script/help_keywords_mysql56.sql:
--------------------------------------------------------------------------------
1 | SELECT distinct hk.name as keyword_name
2 | FROM help_keyword AS hk
3 | LEFT JOIN help_relation AS hr ON hr.help_keyword_id = hk.help_keyword_id
4 | LEFT JOIN help_topic AS ht ON ht.help_topic_id = hr.help_topic_id
5 | LEFT JOIN help_category AS hc ON hc.help_category_id = ht.help_category_id
6 | where hc.parent_category_id NOT IN (4, 8, 21)
7 | order by keyword_name
8 |
--------------------------------------------------------------------------------
/script/help_keywords_mysql57.sql:
--------------------------------------------------------------------------------
1 | SELECT distinct hk.name as keyword_name
2 | FROM help_keyword AS hk
3 | LEFT JOIN help_relation AS hr ON hr.help_keyword_id = hk.help_keyword_id
4 | LEFT JOIN help_topic AS ht ON ht.help_topic_id = hr.help_topic_id
5 | LEFT JOIN help_category AS hc ON hc.help_category_id = ht.help_category_id
6 | where hc.parent_category_id NOT IN (6, 9, 20)
7 | order by keyword_name
8 |
--------------------------------------------------------------------------------
/script/help_keywords_mysql8.sql:
--------------------------------------------------------------------------------
1 | SELECT distinct hk.name as keyword_name
2 | FROM help_keyword AS hk
3 | LEFT JOIN help_relation AS hr ON hr.help_keyword_id = hk.help_keyword_id
4 | LEFT JOIN help_topic AS ht ON ht.help_topic_id = hr.help_topic_id
5 | LEFT JOIN help_category AS hc ON hc.help_category_id = ht.help_category_id
6 | where hc.parent_category_id NOT IN (4, 7, 22)
7 | order by keyword_name
8 |
--------------------------------------------------------------------------------
/token/kind.go:
--------------------------------------------------------------------------------
1 | package token
2 |
3 | // Kind identifies the lexical category of a token.
4 | type Kind int
5 |
6 | // Kind values are assigned by iota in declaration order. kind_string.go is
7 | // generated from this block by stringer and checks every value, so re-run
8 | // go generate after reordering or inserting constants.
9 | //
10 | //go:generate stringer -type Kind kind.go
11 | const (
12 | 	// A keyword (like SELECT)
13 | 	SQLKeyword Kind = iota
14 | 	// Numeric literal
15 | 	Number
16 | 	// A character that could not be tokenized
17 | 	Char
18 | 	// Single quoted string, e.g. 'string'
19 | 	SingleQuotedString
20 | 	// National string, e.g. N'string'
21 | 	NationalStringLiteral
22 | 	// Comma
23 | 	Comma
24 | 	// Whitespace
25 | 	Whitespace
26 | 	// comment node
27 | 	Comment
28 | 	// multiline comment node
29 | 	MultilineComment
30 | 	// = operator
31 | 	Eq
32 | 	// != or <> operator
33 | 	Neq
34 | 	// < operator
35 | 	Lt
36 | 	// > operator
37 | 	Gt
38 | 	// <= operator
39 | 	LtEq
40 | 	// >= operator
41 | 	GtEq
42 | 	// + operator
43 | 	Plus
44 | 	// - operator
45 | 	Minus
46 | 	// * operator
47 | 	Mult
48 | 	// / operator
49 | 	Div
50 | 	// ^ operator
51 | 	Caret
52 | 	// % operator
53 | 	Mod
54 | 	// Left parenthesis `(`
55 | 	LParen
56 | 	// Right parenthesis `)`
57 | 	RParen
58 | 	// Period
59 | 	Period
60 | 	// Colon
61 | 	Colon
62 | 	// DoubleColon
63 | 	DoubleColon
64 | 	// Semicolon
65 | 	Semicolon
66 | 	// Backslash
67 | 	Backslash
68 | 	// Left bracket `[`
69 | 	LBracket
70 | 	// Right bracket `]`
71 | 	RBracket
72 | 	// Ampersand `&`
73 | 	Ampersand
74 | 	// Left brace `{`
75 | 	LBrace
76 | 	// Right brace `}`
77 | 	RBrace
78 | 	// ILLEGAL sqltoken
79 | 	ILLEGAL
80 | )
81 |
--------------------------------------------------------------------------------
/token/kind_string.go:
--------------------------------------------------------------------------------
1 | // Code generated by "stringer -type Kind kind.go"; DO NOT EDIT.
2 |
3 | package token
4 |
5 | import "strconv"
6 |
7 | func _() {
8 | 	// An "invalid array index" compiler error signifies that the constant values have changed.
9 | 	// Re-run the stringer command to generate them again.
10 | 	var x [1]struct{}
11 | 	_ = x[SQLKeyword-0]
12 | 	_ = x[Number-1]
13 | 	_ = x[Char-2]
14 | 	_ = x[SingleQuotedString-3]
15 | 	_ = x[NationalStringLiteral-4]
16 | 	_ = x[Comma-5]
17 | 	_ = x[Whitespace-6]
18 | 	_ = x[Comment-7]
19 | 	_ = x[MultilineComment-8]
20 | 	_ = x[Eq-9]
21 | 	_ = x[Neq-10]
22 | 	_ = x[Lt-11]
23 | 	_ = x[Gt-12]
24 | 	_ = x[LtEq-13]
25 | 	_ = x[GtEq-14]
26 | 	_ = x[Plus-15]
27 | 	_ = x[Minus-16]
28 | 	_ = x[Mult-17]
29 | 	_ = x[Div-18]
30 | 	_ = x[Caret-19]
31 | 	_ = x[Mod-20]
32 | 	_ = x[LParen-21]
33 | 	_ = x[RParen-22]
34 | 	_ = x[Period-23]
35 | 	_ = x[Colon-24]
36 | 	_ = x[DoubleColon-25]
37 | 	_ = x[Semicolon-26]
38 | 	_ = x[Backslash-27]
39 | 	_ = x[LBracket-28]
40 | 	_ = x[RBracket-29]
41 | 	_ = x[Ampersand-30]
42 | 	_ = x[LBrace-31]
43 | 	_ = x[RBrace-32]
44 | 	_ = x[ILLEGAL-33]
45 | }
46 |
47 | const _Kind_name = "SQLKeywordNumberCharSingleQuotedStringNationalStringLiteralCommaWhitespaceCommentMultilineCommentEqNeqLtGtLtEqGtEqPlusMinusMultDivCaretModLParenRParenPeriodColonDoubleColonSemicolonBackslashLBracketRBracketAmpersandLBraceRBraceILLEGAL"
48 |
49 | var _Kind_index = [...]uint8{0, 10, 16, 20, 38, 59, 64, 74, 81, 97, 99, 102, 104, 106, 110, 114, 118, 123, 127, 130, 135, 138, 144, 150, 156, 161, 172, 181, 190, 198, 206, 215, 221, 227, 234}
50 |
51 | func (i Kind) String() string {
52 | 	if i < 0 || i >= Kind(len(_Kind_index)-1) {
53 | 		return "Kind(" + strconv.FormatInt(int64(i), 10) + ")"
54 | 	}
55 | 	return _Kind_name[_Kind_index[i]:_Kind_index[i+1]]
56 | }
57 |
--------------------------------------------------------------------------------