├── .github
│   ├── CODEOWNERS
│   └── workflows
│       ├── release-binary.yml
│       ├── release.yaml
│       └── test.yml
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── astra
│   ├── bundle.go
│   ├── bundle_test.go
│   ├── endpoint.go
│   └── endpoint_test.go
├── codecs
│   ├── codec.go
│   ├── partial_codecs.go
│   ├── partial_codecs_test.go
│   ├── reader.go
│   └── types.go
├── cql-proxy.png
├── go.mod
├── go.sum
├── k8s
│   ├── cql-proxy-configmap.yml
│   └── cql-proxy.yml
├── parser
│   ├── bench
│   │   └── bench_parser.go
│   ├── identifier.go
│   ├── lexer.go
│   ├── lexer.rl
│   ├── lexer_test.go
│   ├── metadata.go
│   ├── parse_batch.go
│   ├── parse_batch_test.go
│   ├── parse_delete.go
│   ├── parse_delete_test.go
│   ├── parse_insert.go
│   ├── parse_insert_test.go
│   ├── parse_relation.go
│   ├── parse_relation_test.go
│   ├── parse_select.go
│   ├── parse_term.go
│   ├── parse_term_test.go
│   ├── parse_update.go
│   ├── parse_update_test.go
│   ├── parse_updateop.go
│   ├── parse_updateop_test.go
│   ├── parser.go
│   ├── parser_test.go
│   └── parser_utils.go
├── proxy.go
├── proxy
│   ├── proxy.go
│   ├── proxy_retries_test.go
│   ├── proxy_test.go
│   ├── request.go
│   ├── retrypolicy.go
│   ├── retrypolicy_test.go
│   ├── run.go
│   └── run_test.go
└── proxycore
    ├── auth.go
    ├── clientconn.go
    ├── clientconn_test.go
    ├── cluster.go
    ├── cluster_test.go
    ├── conn.go
    ├── conn_test.go
    ├── connpool.go
    ├── connpool_test.go
    ├── endpoint.go
    ├── endpoint_test.go
    ├── errors.go
    ├── host.go
    ├── lb.go
    ├── lb_test.go
    ├── log.go
    ├── mockcluster.go
    ├── reconnpolicy.go
    ├── reconnpolicy_test.go
    ├── requests.go
    ├── requests_test.go
    ├── resultset.go
    ├── session.go
    └── session_test.go

--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
* @mpenick @dougwettlaufer

--------------------------------------------------------------------------------
/.github/workflows/release-binary.yml:
--------------------------------------------------------------------------------
name: Release Binaries

on:
  workflow_dispatch:
  push:
    tags:
      - 'v*.*.*'

jobs:
  build:
    name: Build and Upload Release Assets
    runs-on: ubuntu-latest
    container: golang:1.24.2-bullseye
    strategy:
      matrix:
        goosarch:
          - "linux/amd64"
          - "linux/arm64"
          - "windows/amd64"
          - "darwin/amd64"
          - "darwin/arm64"
    env:
      GO111MODULE: on
      CGO_ENABLED: 0
    steps:
      - name: Checkout code
        # v2 runs on a deprecated Node.js runtime.
        uses: actions/checkout@v4
      - name: Build ${{ matrix.goosarch }} binary
        run: |
          # apt-get has a stable CLI for scripts; apt warns when scripted.
          apt-get update
          apt-get -y install zip

          export GOOSARCH=${{ matrix.goosarch }}
          export GOOS=${GOOSARCH%/*}
          export GOARCH=${GOOSARCH#*/}

          mkdir -p artifacts

          if [ "$GOOS" = "windows" ]; then
            go build -o cql-proxy.exe
            zip -vr cql-proxy-${GOOS}-${GOARCH}-${{ github.ref_name }}.zip cql-proxy.exe LICENSE
            sha256sum cql-proxy-${GOOS}-${GOARCH}-${{ github.ref_name }}.zip | cut -d ' ' -f 1 > cql-proxy-${GOOS}-${GOARCH}-${{ github.ref_name }}-sha256.txt
          else
            go build -o cql-proxy
            tar cvfz cql-proxy-${GOOS}-${GOARCH}-${{ github.ref_name }}.tgz cql-proxy LICENSE
            sha256sum cql-proxy-${GOOS}-${GOARCH}-${{ github.ref_name }}.tgz | cut -d ' ' -f 1 > cql-proxy-${GOOS}-${GOARCH}-${{ github.ref_name }}-sha256.txt
          fi

          mv cql-proxy-* artifacts
      - name: Upload ${{ matrix.goosarch }} binaries
        uses: softprops/action-gh-release@v2
        with:
          name: ${{ github.ref_name }}
          files: |
            artifacts/*

--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
name: Docker Release

on:
  push:
    tags:
      - 'v*.*.*'

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      # Pin a major version instead of tracking the moving `master` branch.
      - uses: actions/checkout@v4
      -
        name: Set version
        id: vars
        # The `::set-output` workflow command is deprecated; step outputs are
        # written to the file named by $GITHUB_OUTPUT instead.
        run: echo "tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_OUTPUT"
      -
        name: Set up QEMU
        uses: docker/setup-qemu-action@v3
      -
        name: Set up Docker Buildx
        id: buildx
        uses: docker/setup-buildx-action@v3
        with:
          install: true
      -
        name: Login to DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}
      -
        name: Build and push
        id: docker_build
        uses: docker/build-push-action@v6
        with:
          push: true
          platforms: linux/amd64,linux/arm64
          tags: datastax/cql-proxy:${{ steps.vars.outputs.tag }}

--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
on: [push]
name: Test
jobs:
  test:
    strategy:
      matrix:
        go-version: [1.24.2]
        os: [ubuntu-latest, macos-latest]
    runs-on: ${{ matrix.os }}
    steps:
      - name: Add loopback aliases
        if: matrix.os == 'macos-latest'
        shell: bash
        run: |
          sudo ifconfig lo0 alias 127.0.0.2 up
          sudo ifconfig lo0 alias 127.0.0.3 up
          sudo ifconfig lo0 alias 127.0.0.4 up
          sudo ifconfig lo0 alias 127.0.0.5 up
          sudo ifconfig lo0 alias 127.0.0.6 up
      # Checkout comes first so setup-go's module cache can find go.sum.
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Install Go
        uses: actions/setup-go@v5
        with:
          go-version: ${{ matrix.go-version }}
      - name: Test
        run: go test -v ./...

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
*.zip
cql-proxy

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# Stage keyword casing matches FROM (BuildKit warns on mixed case).
FROM golang:1.24.2 AS builder

# Disable cgo to remove gcc dependency
ENV CGO_ENABLED=0

WORKDIR /go/src/cql-proxy

# Grab the dependencies
COPY go.mod go.sum ./
RUN go mod download

# Copy in source
COPY . ./

# Build and install binary
RUN go install github.com/datastax/cql-proxy

# Run unit tests
RUN go test -short -v ./...

# A new clean image with just the binary. alpine 3.14 is end-of-life and no
# longer receives security updates.
FROM alpine:3.20
RUN apk add --no-cache ca-certificates

EXPOSE 9042

# Copy in the binary
COPY --from=builder /go/bin/cql-proxy .

ENTRYPOINT ["/cql-proxy"]

--------------------------------------------------------------------------------
/astra/bundle.go:
--------------------------------------------------------------------------------
// Copyright (c) DataStax, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 | 15 | package astra 16 | 17 | import ( 18 | "archive/zip" 19 | "bytes" 20 | "context" 21 | "crypto/tls" 22 | "crypto/x509" 23 | "encoding/json" 24 | "errors" 25 | "fmt" 26 | "io" 27 | "io/ioutil" 28 | "net/http" 29 | "runtime" 30 | "time" 31 | 32 | "github.com/datastax/astra-client-go/v2/astra" 33 | ) 34 | 35 | type Bundle struct { 36 | TLSConfig *tls.Config 37 | Host string 38 | Port int 39 | } 40 | 41 | func LoadBundleZip(reader *zip.Reader) (*Bundle, error) { 42 | contents, err := extract(reader) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | config := struct { 48 | Host string `json:"host"` 49 | Port int `json:"port"` 50 | }{} 51 | err = json.Unmarshal(contents["config.json"], &config) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | rootCAs, err := createCertPool() 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | ok := rootCAs.AppendCertsFromPEM(contents["ca.crt"]) 62 | if !ok { 63 | return nil, fmt.Errorf("the provided CA cert could not be added to the root CA pool") 64 | } 65 | 66 | cert, err := tls.X509KeyPair(contents["cert"], contents["key"]) 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | return &Bundle{ 72 | TLSConfig: &tls.Config{ 73 | RootCAs: rootCAs, 74 | Certificates: []tls.Certificate{cert}, 75 | ServerName: config.Host, 76 | }, 77 | Host: config.Host, 78 | Port: config.Port, 79 | }, nil 80 | } 81 | 82 | func LoadBundleZipFromPath(path string) (*Bundle, error) { 83 | reader, err := zip.OpenReader(path) 84 | if err != nil { 85 | return nil, err 86 | } 87 | 88 | defer func(reader *zip.ReadCloser) { 89 | _ = reader.Close() 90 | }(reader) 91 | 92 | return LoadBundleZip(&reader.Reader) 93 | } 94 | 95 | func LoadBundleZipFromURL(url, databaseID, token string, timeout time.Duration) (*Bundle, error) { 96 | ctx, cancel := context.WithTimeout(context.Background(), timeout) 97 | defer cancel() 98 | 99 | credsURL, err := generateSecureBundleURLWithResponse(url, databaseID, token, ctx) 100 | if err != nil { 101 | 
return nil, fmt.Errorf("error generating secure bundle zip URLs: %v", err) 102 | } 103 | 104 | resp, err := http.Get(credsURL.DownloadURL) 105 | if err != nil { 106 | return nil, fmt.Errorf("error downloading secure bundle zip: %v", err) 107 | } 108 | 109 | defer resp.Body.Close() 110 | 111 | body, err := readAllWithTimeout(resp.Body, ctx) 112 | if err != nil { 113 | return nil, fmt.Errorf("error reading downloaded secure bundle zip: %v", err) 114 | } 115 | 116 | reader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) 117 | if err != nil { 118 | return nil, fmt.Errorf("error creating zip reader for secure bundle zip: %v", err) 119 | } 120 | 121 | return LoadBundleZip(reader) 122 | } 123 | 124 | func readAllWithTimeout(r io.Reader, ctx context.Context) (bytes []byte, err error) { 125 | ch := make(chan struct{}) 126 | 127 | go func() { 128 | bytes, err = ioutil.ReadAll(r) 129 | close(ch) 130 | }() 131 | 132 | select { 133 | case <-ch: 134 | case <-ctx.Done(): 135 | return nil, errors.New("timeout reading data") 136 | } 137 | 138 | return bytes, err 139 | } 140 | 141 | func generateSecureBundleURLWithResponse(url, databaseID, token string, ctx context.Context) (*astra.CredsURL, error) { 142 | client, err := astra.NewClientWithResponses(url, func(c *astra.Client) error { 143 | c.RequestEditors = append(c.RequestEditors, func(ctx context.Context, req *http.Request) error { 144 | req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) 145 | return nil 146 | }) 147 | return nil 148 | }) 149 | if err != nil { 150 | return nil, err 151 | } 152 | // return all bundles, as returning only one causes issue with response deserializing 153 | // client code generated by astra-client-go from Swagger definition does not support 'oneOf' clause 154 | // used as a response of /v2/databases/{databaseID}/secureBundleURL endpoint 155 | // (https://github.com/oapi-codegen/oapi-codegen/issues/1665) 156 | returnAllBundles := true 157 | res, err := 
client.GenerateSecureBundleURLWithResponse(ctx, databaseID, &astra.GenerateSecureBundleURLParams{All: &returnAllBundles}) 158 | if err != nil { 159 | return nil, fmt.Errorf("error generating bundle urls: %v", err) 160 | } 161 | 162 | if res.StatusCode() != http.StatusOK { 163 | return nil, fmt.Errorf("unable to generate bundle urls, failed with status code %d", res.StatusCode()) 164 | } 165 | 166 | return &(*res.JSON200)[0], nil 167 | } 168 | 169 | func extract(reader *zip.Reader) (map[string][]byte, error) { 170 | contents := make(map[string][]byte) 171 | 172 | for _, file := range reader.File { 173 | switch file.Name { 174 | case "config.json", "cert", "key", "ca.crt": 175 | bytes, err := loadBytes(file) 176 | if err != nil { 177 | return nil, err 178 | } 179 | contents[file.Name] = bytes 180 | } 181 | } 182 | 183 | for _, file := range []string{"config.json", "cert", "key", "ca.crt"} { 184 | if _, ok := contents[file]; !ok { 185 | return nil, fmt.Errorf("bundle missing '%s' file", file) 186 | } 187 | } 188 | 189 | return contents, nil 190 | } 191 | 192 | func loadBytes(file *zip.File) ([]byte, error) { 193 | r, err := file.Open() 194 | if err != nil { 195 | return nil, err 196 | } 197 | defer func(r io.ReadCloser) { 198 | _ = r.Close() 199 | }(r) 200 | return ioutil.ReadAll(r) 201 | } 202 | 203 | func createCertPool() (*x509.CertPool, error) { 204 | ca, err := x509.SystemCertPool() 205 | if err != nil && runtime.GOOS == "windows" { 206 | return x509.NewCertPool(), nil 207 | } 208 | return ca, err 209 | } 210 | -------------------------------------------------------------------------------- /astra/endpoint.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package astra 16 | 17 | import ( 18 | "context" 19 | "crypto/tls" 20 | "crypto/x509" 21 | "encoding/json" 22 | "errors" 23 | "fmt" 24 | "net/http" 25 | "sync" 26 | "time" 27 | 28 | "github.com/datastax/cql-proxy/proxycore" 29 | ) 30 | 31 | type astraResolver struct { 32 | sniProxyAddress string 33 | region string 34 | bundle *Bundle 35 | mu *sync.Mutex 36 | timeout time.Duration 37 | } 38 | 39 | type astraEndpoint struct { 40 | addr string 41 | key string 42 | tlsConfig *tls.Config 43 | } 44 | 45 | func NewResolver(bundle *Bundle, timeout time.Duration) proxycore.EndpointResolver { 46 | return &astraResolver{ 47 | bundle: bundle, 48 | mu: &sync.Mutex{}, 49 | timeout: timeout, 50 | } 51 | } 52 | 53 | func (r *astraResolver) Resolve(ctx context.Context) ([]proxycore.Endpoint, error) { 54 | var metadata *astraMetadata 55 | 56 | ctx, cancel := context.WithTimeout(ctx, r.timeout) 57 | defer cancel() 58 | 59 | httpsClient := &http.Client{ 60 | Transport: &http.Transport{ 61 | TLSClientConfig: r.bundle.TLSConfig.Clone(), 62 | }, 63 | } 64 | 65 | url := fmt.Sprintf("https://%s:%d/metadata", r.bundle.Host, r.bundle.Port) 66 | req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | response, err := httpsClient.Do(req) 72 | if err != nil { 73 | return nil, fmt.Errorf("unable to get metadata from %s: %w", url, err) 74 | } 75 | 76 | body, err := readAllWithTimeout(response.Body, ctx) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | err = 
json.Unmarshal(body, &metadata) 82 | if err != nil { 83 | return nil, err 84 | } 85 | 86 | sniProxyAddress := metadata.ContactInfo.SniProxyAddress 87 | 88 | r.mu.Lock() 89 | r.sniProxyAddress = sniProxyAddress 90 | r.region = metadata.Region 91 | r.mu.Unlock() 92 | 93 | var endpoints []proxycore.Endpoint 94 | for _, cp := range metadata.ContactInfo.ContactPoints { 95 | endpoints = append(endpoints, &astraEndpoint{ 96 | addr: sniProxyAddress, 97 | key: fmt.Sprintf("%s:%s", sniProxyAddress, cp), 98 | tlsConfig: copyTLSConfig(r.bundle, cp), 99 | }) 100 | } 101 | 102 | return endpoints, nil 103 | } 104 | 105 | func (r *astraResolver) getSNIProxyAddressAndRegion() (string, string, error) { 106 | r.mu.Lock() 107 | defer r.mu.Unlock() 108 | if len(r.sniProxyAddress) == 0 { 109 | return "", "", errors.New("SNI proxy address (and region) never resolved") 110 | } 111 | return r.sniProxyAddress, r.region, nil 112 | } 113 | 114 | func (r *astraResolver) NewEndpoint(row proxycore.Row) (proxycore.Endpoint, error) { 115 | sniProxyAddress, region, err := r.getSNIProxyAddressAndRegion() 116 | if err != nil { 117 | return nil, err 118 | } 119 | dc, err := row.StringByName("data_center") 120 | if err != nil { 121 | return nil, err 122 | } 123 | if len(region) > 0 && region != dc { 124 | return nil, proxycore.IgnoreEndpoint 125 | } 126 | hostId, err := row.UUIDByName("host_id") 127 | if err != nil { 128 | return nil, err 129 | } else { 130 | return &astraEndpoint{ 131 | addr: sniProxyAddress, 132 | key: fmt.Sprintf("%s:%s", sniProxyAddress, &hostId), 133 | tlsConfig: copyTLSConfig(r.bundle, hostId.String()), 134 | }, nil 135 | } 136 | } 137 | 138 | func (a astraEndpoint) String() string { 139 | return a.Key() 140 | } 141 | 142 | func (a astraEndpoint) Key() string { 143 | return a.key 144 | } 145 | 146 | func (a astraEndpoint) Addr() string { 147 | return a.addr 148 | } 149 | 150 | func (a astraEndpoint) IsResolved() bool { 151 | return false 152 | } 153 | 154 | func (a astraEndpoint) 
TLSConfig() *tls.Config { 155 | return a.tlsConfig 156 | } 157 | 158 | func copyTLSConfig(bundle *Bundle, serverName string) *tls.Config { 159 | tlsConfig := bundle.TLSConfig.Clone() 160 | tlsConfig.ServerName = serverName 161 | tlsConfig.InsecureSkipVerify = true 162 | tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 163 | certs := make([]*x509.Certificate, len(rawCerts)) 164 | for i, asn1Data := range rawCerts { 165 | cert, err := x509.ParseCertificate(asn1Data) 166 | if err != nil { 167 | return errors.New("tls: failed to parse certificate from server: " + err.Error()) 168 | } 169 | certs[i] = cert 170 | } 171 | 172 | opts := x509.VerifyOptions{ 173 | Roots: tlsConfig.RootCAs, 174 | CurrentTime: time.Now(), 175 | DNSName: bundle.Host, 176 | Intermediates: x509.NewCertPool(), 177 | } 178 | for _, cert := range certs[1:] { 179 | opts.Intermediates.AddCert(cert) 180 | } 181 | var err error 182 | verifiedChains, err = certs[0].Verify(opts) 183 | return err 184 | } 185 | return tlsConfig 186 | } 187 | 188 | type contactInfo struct { 189 | TypeName string `json:"type"` 190 | LocalDc string `json:"local_dc"` 191 | SniProxyAddress string `json:"sni_proxy_address"` 192 | ContactPoints []string `json:"contact_points"` 193 | } 194 | 195 | type astraMetadata struct { 196 | Version int `json:"version"` 197 | Region string `json:"region"` 198 | ContactInfo contactInfo `json:"contact_info"` 199 | } 200 | -------------------------------------------------------------------------------- /astra/endpoint_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package astra 16 | 17 | import ( 18 | "context" 19 | "crypto/tls" 20 | "encoding/json" 21 | "errors" 22 | "net" 23 | "net/http" 24 | "os" 25 | "testing" 26 | "time" 27 | 28 | "github.com/datastax/cql-proxy/codecs" 29 | "github.com/datastax/cql-proxy/proxycore" 30 | "github.com/datastax/go-cassandra-native-protocol/datatype" 31 | "github.com/datastax/go-cassandra-native-protocol/message" 32 | "github.com/datastax/go-cassandra-native-protocol/primitive" 33 | "github.com/stretchr/testify/assert" 34 | "github.com/stretchr/testify/require" 35 | ) 36 | 37 | const sniProxyAddr = "localhost:8080" 38 | 39 | var contactPoints = []string{ 40 | "a2e24181-d732-402a-ab06-894a8b2f6094", 41 | "ce00ba58-a377-4022-ba09-00394ee66cfb", 42 | "9e339fe3-2bf2-45ce-a660-76951f39a8e8", 43 | } 44 | 45 | func TestMain(m *testing.M) { 46 | serv, err := runTestMetaSvcAsync(sniProxyAddr, contactPoints) 47 | if err != nil { 48 | panic(err) 49 | } 50 | r := m.Run() 51 | _ = serv.Close() 52 | os.Exit(r) 53 | } 54 | 55 | func TestAstraResolver_Resolve(t *testing.T) { 56 | resolver := createResolver(t) 57 | endpoints, err := resolver.Resolve(context.Background()) 58 | require.NoError(t, err) 59 | 60 | for _, endpoint := range endpoints { 61 | assert.False(t, endpoint.IsResolved()) 62 | assert.Equal(t, sniProxyAddr, endpoint.Addr()) 63 | assert.Contains(t, contactPoints, endpoint.TLSConfig().ServerName) 64 | } 65 | } 66 | 67 | func TestAstraResolver_NewEndpoint(t *testing.T) { 68 | resolver := createResolver(t) 69 | _, err := 
resolver.Resolve(context.Background()) 70 | require.NoError(t, err) 71 | 72 | const hostId = "a2e24181-d732-402a-ab06-894a8b2f6094" 73 | 74 | rs := proxycore.NewResultSet(&message.RowsResult{ 75 | Metadata: &message.RowsMetadata{ 76 | ColumnCount: 1, 77 | Columns: []*message.ColumnMetadata{ 78 | { 79 | Keyspace: "system", 80 | Table: "peers", 81 | Name: "host_id", 82 | Index: 0, 83 | Type: datatype.Uuid, 84 | }, 85 | { 86 | Keyspace: "system", 87 | Table: "peers", 88 | Name: "data_center", 89 | Index: 1, 90 | Type: datatype.Varchar, 91 | }, 92 | }, 93 | }, 94 | Data: message.RowSet{ 95 | message.Row{makeUUID(hostId), makeVarchar("us-east1")}, 96 | }, 97 | }, primitive.ProtocolVersion4) 98 | 99 | endpoint, err := resolver.NewEndpoint(rs.Row(0)) 100 | assert.NotNil(t, endpoint) 101 | assert.Nil(t, err) 102 | assert.Contains(t, endpoint.Key(), hostId) 103 | } 104 | 105 | func TestAstraResolver_NewEndpoint_Ignored(t *testing.T) { 106 | resolver := createResolver(t) 107 | _, err := resolver.Resolve(context.Background()) 108 | require.NoError(t, err) 109 | 110 | const hostId = "a2e24181-d732-402a-ab06-894a8b2f6094" 111 | 112 | rs := proxycore.NewResultSet(&message.RowsResult{ 113 | Metadata: &message.RowsMetadata{ 114 | ColumnCount: 1, 115 | Columns: []*message.ColumnMetadata{ 116 | { 117 | Keyspace: "system", 118 | Table: "peers", 119 | Name: "host_id", 120 | Index: 0, 121 | Type: datatype.Uuid, 122 | }, 123 | { 124 | Keyspace: "system", 125 | Table: "peers", 126 | Name: "data_center", 127 | Index: 1, 128 | Type: datatype.Varchar, 129 | }, 130 | }, 131 | }, 132 | Data: message.RowSet{ 133 | message.Row{makeUUID(hostId), makeVarchar("ignored")}, 134 | }, 135 | }, primitive.ProtocolVersion4) 136 | 137 | endpoint, err := resolver.NewEndpoint(rs.Row(0)) 138 | assert.Nil(t, endpoint) 139 | assert.ErrorIs(t, err, proxycore.IgnoreEndpoint) 140 | } 141 | 142 | func TestAstraResolver_NewEndpointInvalidHostID(t *testing.T) { 143 | resolver := createResolver(t) 144 | _, err := 
resolver.Resolve(context.Background()) 145 | require.NoError(t, err) 146 | 147 | rs := proxycore.NewResultSet(&message.RowsResult{ 148 | Metadata: &message.RowsMetadata{ 149 | ColumnCount: 1, 150 | Columns: []*message.ColumnMetadata{ 151 | { 152 | Keyspace: "system", 153 | Table: "peers", 154 | Name: "host_id", 155 | Index: 0, 156 | Type: datatype.Uuid, 157 | }, 158 | }, 159 | }, 160 | Data: message.RowSet{ 161 | message.Row{nil}, // Null value 162 | }, 163 | }, primitive.ProtocolVersion4) 164 | 165 | endpoint, err := resolver.NewEndpoint(rs.Row(0)) 166 | assert.Nil(t, endpoint) 167 | assert.Error(t, err, "ignoring host because its `host_id` is not set or is invalid") 168 | } 169 | 170 | func TestAstraResolver_Timeout(t *testing.T) { 171 | ctx, cancel := context.WithTimeout(context.Background(), 1) // Very short timeout 172 | defer cancel() 173 | 174 | resolver := createResolver(t) 175 | _, err := resolver.Resolve(ctx) 176 | assert.ErrorIs(t, err, context.DeadlineExceeded) // Expect a timeout 177 | } 178 | 179 | func createResolver(t *testing.T) proxycore.EndpointResolver { 180 | path, err := writeBundle("127.0.0.1", 8080) 181 | require.NoError(t, err) 182 | 183 | bundle, err := LoadBundleZipFromPath(path) 184 | require.NoError(t, err) 185 | 186 | return NewResolver(bundle, 10*time.Second) 187 | } 188 | 189 | func runTestMetaSvcAsync(sniProxyAddr string, contactPoints []string) (*http.Server, error) { 190 | host, _, err := net.SplitHostPort(sniProxyAddr) 191 | if err != nil { 192 | return nil, err 193 | } 194 | 195 | tlsConfig, err := createServerTLSConfig(host) 196 | if err != nil { 197 | return nil, err 198 | } 199 | 200 | listener, err := net.Listen("tcp", sniProxyAddr) 201 | if err != nil { 202 | return nil, err 203 | } 204 | 205 | mux := http.NewServeMux() 206 | mux.HandleFunc("/metadata", func(writer http.ResponseWriter, request *http.Request) { 207 | res, err := json.Marshal(astraMetadata{ 208 | Version: 1, 209 | Region: "us-east1", 210 | ContactInfo: 
contactInfo{ 211 | SniProxyAddress: sniProxyAddr, 212 | ContactPoints: contactPoints, 213 | }, 214 | }) 215 | if err != nil { 216 | writer.WriteHeader(500) 217 | } else { 218 | _, _ = writer.Write(res) 219 | } 220 | }) 221 | 222 | serv := &http.Server{ 223 | Addr: sniProxyAddr, 224 | TLSConfig: tlsConfig, 225 | Handler: mux, 226 | } 227 | 228 | go func() { 229 | _ = serv.ServeTLS(listener, "", "") 230 | }() 231 | 232 | return serv, nil 233 | } 234 | 235 | func createServerTLSConfig(dnsName string) (*tls.Config, error) { 236 | rootCAs, err := createCertPool() 237 | if err != nil { 238 | return nil, err 239 | } 240 | 241 | if !rootCAs.AppendCertsFromPEM(testCAPEM) { 242 | return nil, errors.New("unable to add cert to CA pool") 243 | } 244 | 245 | cert, err := tls.X509KeyPair(testCertPEM, testKeyPEM) 246 | if err != nil { 247 | return nil, err 248 | } 249 | 250 | return &tls.Config{ 251 | RootCAs: rootCAs, 252 | ClientCAs: rootCAs, 253 | Certificates: []tls.Certificate{cert}, 254 | ClientAuth: tls.RequireAndVerifyClientCert, 255 | }, nil 256 | } 257 | 258 | func makeUUID(uuid string) []byte { 259 | parsedUuid, _ := primitive.ParseUuid(uuid) 260 | bytes, _ := codecs.EncodeType(datatype.Uuid, primitive.ProtocolVersion4, parsedUuid) 261 | return bytes 262 | } 263 | 264 | func makeVarchar(s string) []byte { 265 | bytes, _ := codecs.EncodeType(datatype.Varchar, primitive.ProtocolVersion4, s) 266 | return bytes 267 | } 268 | -------------------------------------------------------------------------------- /codecs/codec.go: -------------------------------------------------------------------------------- 1 | package codecs 2 | 3 | import ( 4 | "github.com/datastax/go-cassandra-native-protocol/compression/lz4" 5 | "github.com/datastax/go-cassandra-native-protocol/compression/snappy" 6 | "github.com/datastax/go-cassandra-native-protocol/frame" 7 | "github.com/datastax/go-cassandra-native-protocol/message" 8 | ) 9 | 10 | var ( 11 | CustomMessageCodecs = []message.Codec{ 12 | 
&partialQueryCodec{}, &partialExecuteCodec{}, &partialBatchCodec{}, 13 | } 14 | 15 | CustomRawCodec = frame.NewRawCodec(CustomMessageCodecs...) 16 | 17 | CustomRawCodecsWithCompression = map[string]frame.RawCodec{ 18 | "lz4": frame.NewRawCodecWithCompression(&lz4.Compressor{}, CustomMessageCodecs...), 19 | "snappy": frame.NewRawCodecWithCompression(&snappy.Compressor{}, CustomMessageCodecs...), 20 | } 21 | 22 | DefaultRawCodec = frame.NewRawCodec() 23 | DefaultRawCodecsWithCompression = map[string]frame.RawCodec{ 24 | "lz4": frame.NewRawCodecWithCompression(&lz4.Compressor{}), 25 | "snappy": frame.NewRawCodecWithCompression(&snappy.Compressor{}), 26 | } 27 | 28 | CompressionNames = []string{"lz4", "snappy"} 29 | ) 30 | -------------------------------------------------------------------------------- /codecs/reader.go: -------------------------------------------------------------------------------- 1 | package codecs 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | ) 7 | 8 | // FrameBodyReader is an [io.Reader] that also contains a reference to the underlying bytes buffer for a frame body. 9 | // This is used to decode "partial" decode message types without requiring copying the underlying data for certain frame 10 | // fields. This can avoid extra allocations and copies. 
11 | type FrameBodyReader struct { 12 | *bytes.Reader 13 | Body []byte 14 | } 15 | 16 | func NewFrameBodyReader(b []byte) *FrameBodyReader { 17 | return &FrameBodyReader{ 18 | Reader: bytes.NewReader(b), 19 | Body: b, 20 | } 21 | } 22 | 23 | func (r *FrameBodyReader) Position() int64 { 24 | pos, _ := r.Seek(0, io.SeekCurrent) // Doesn't fail 25 | return pos 26 | } 27 | 28 | func (r *FrameBodyReader) BytesSince(pos int64) []byte { 29 | return r.Body[pos:r.Position()] 30 | } 31 | 32 | func (r *FrameBodyReader) RemainingBytes() []byte { 33 | return r.Body[r.Position():] 34 | } 35 | -------------------------------------------------------------------------------- /codecs/types.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package codecs 16 | 17 | import ( 18 | "fmt" 19 | 20 | "github.com/datastax/go-cassandra-native-protocol/datacodec" 21 | "github.com/datastax/go-cassandra-native-protocol/datatype" 22 | "github.com/datastax/go-cassandra-native-protocol/primitive" 23 | ) 24 | 25 | var primitiveCodecs = map[datatype.DataType]datacodec.Codec{ 26 | datatype.Ascii: datacodec.Ascii, 27 | datatype.Bigint: datacodec.Bigint, 28 | datatype.Blob: datacodec.Blob, 29 | datatype.Boolean: datacodec.Boolean, 30 | datatype.Counter: datacodec.Counter, 31 | datatype.Decimal: datacodec.Decimal, 32 | datatype.Double: datacodec.Double, 33 | datatype.Float: datacodec.Float, 34 | datatype.Inet: datacodec.Inet, 35 | datatype.Int: datacodec.Int, 36 | datatype.Smallint: datacodec.Smallint, 37 | datatype.Varchar: datacodec.Varchar, 38 | datatype.Timeuuid: datacodec.Timeuuid, 39 | datatype.Tinyint: datacodec.Tinyint, 40 | datatype.Uuid: datacodec.Uuid, 41 | datatype.Varint: datacodec.Varint, 42 | } 43 | 44 | func EncodeType(dt datatype.DataType, version primitive.ProtocolVersion, val interface{}) ([]byte, error) { 45 | c, err := codecFromDataType(dt) 46 | if err != nil { 47 | return nil, err 48 | } 49 | return c.Encode(val, version) 50 | } 51 | 52 | func DecodeType(dt datatype.DataType, version primitive.ProtocolVersion, bytes []byte) (interface{}, error) { 53 | c, err := codecFromDataType(dt) 54 | if err != nil { 55 | return nil, err 56 | } 57 | var dest interface{} 58 | _, err = c.Decode(bytes, &dest, version) 59 | return dest, err 60 | } 61 | 62 | func codecFromDataType(dt datatype.DataType) (datacodec.Codec, error) { 63 | switch dt.Code() { 64 | case primitive.DataTypeCodeList: 65 | listType := dt.(*datatype.List) 66 | return datacodec.NewList(datatype.NewList(listType.ElementType)) 67 | case primitive.DataTypeCodeSet: 68 | setType := dt.(*datatype.Set) 69 | return datacodec.NewSet(datatype.NewSet(setType.ElementType)) 70 | case primitive.DataTypeCodeMap: 71 | mapType := dt.(*datatype.Map) 72 | 
return datacodec.NewMap(datatype.NewMap(mapType.KeyType, mapType.ValueType)) 73 | default: 74 | codec, ok := primitiveCodecs[dt] 75 | if !ok { 76 | return nil, fmt.Errorf("no codec for data type %v", dt) 77 | } 78 | return codec, nil 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /cql-proxy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastax/cql-proxy/0fb8063d2460022c71565d79eb83aa73b7a4c18a/cql-proxy.png -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/datastax/cql-proxy 2 | 3 | go 1.24.2 4 | 5 | require ( 6 | github.com/alecthomas/kong v0.2.17 7 | github.com/datastax/astra-client-go/v2 v2.2.54 8 | github.com/datastax/go-cassandra-native-protocol v0.0.0-20220706104457-5e8aad05cf90 9 | github.com/hashicorp/golang-lru v0.5.4 10 | github.com/stretchr/testify v1.8.1 11 | go.uber.org/atomic v1.8.0 12 | go.uber.org/zap v1.17.0 13 | gopkg.in/yaml.v2 v2.4.0 14 | ) 15 | 16 | require ( 17 | github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect 18 | github.com/davecgh/go-spew v1.1.1 // indirect 19 | github.com/deepmap/oapi-codegen v1.12.4 // indirect 20 | github.com/gocql/gocql v1.7.0 // indirect 21 | github.com/golang/snappy v0.0.3 // indirect 22 | github.com/google/uuid v1.3.0 // indirect 23 | github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect 24 | github.com/kr/text v0.2.0 // indirect 25 | github.com/pierrec/lz4/v4 v4.0.3 // indirect 26 | github.com/pkg/errors v0.9.1 // indirect 27 | github.com/pmezard/go-difflib v1.0.0 // indirect 28 | github.com/rs/zerolog v1.20.0 // indirect 29 | go.uber.org/multierr v1.7.0 // indirect 30 | gopkg.in/inf.v0 v0.9.1 // indirect 31 | gopkg.in/yaml.v3 v3.0.1 // indirect 32 | ) 33 | 
-------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= 2 | github.com/alecthomas/kong v0.2.17 h1:URDISCI96MIgcIlQyoCAlhOmrSw6pZScBNkctg8r0W0= 3 | github.com/alecthomas/kong v0.2.17/go.mod h1:ka3VZ8GZNPXv9Ov+j4YNLkI8mTuhXyr/0ktSlqIydQQ= 4 | github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= 5 | github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= 6 | github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= 7 | github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= 8 | github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= 9 | github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= 10 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 11 | github.com/datastax/astra-client-go/v2 v2.2.54 h1:R2k9ek9zaU15cLD96np5gsj12oZhK3Z5/tSytjQagO8= 12 | github.com/datastax/astra-client-go/v2 v2.2.54/go.mod h1:zxXWuqDkYia7PzFIL3T7RmjChc9LN81UnfI2yB4kE7M= 13 | github.com/datastax/go-cassandra-native-protocol v0.0.0-20220706104457-5e8aad05cf90 h1:SiFe3gwoHPt95ly6HLjwyyItxROxCUJuxqqTnguR5ac= 14 | github.com/datastax/go-cassandra-native-protocol v0.0.0-20220706104457-5e8aad05cf90/go.mod h1:6FzirJfdffakAVqmHjwVfFkpru/gNbIazUOK5rIhndc= 15 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 16 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 17 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 18 | 
github.com/deepmap/oapi-codegen v1.12.4 h1:pPmn6qI9MuOtCz82WY2Xaw46EQjgvxednXXrP7g5Q2s= 19 | github.com/deepmap/oapi-codegen v1.12.4/go.mod h1:3lgHGMu6myQ2vqbbTXH2H1o4eXFTGnFiDaOaKKl5yas= 20 | github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus= 21 | github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4= 22 | github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= 23 | github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= 24 | github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= 25 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 26 | github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= 27 | github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= 28 | github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= 29 | github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= 30 | github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE= 31 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 32 | github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= 33 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 34 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 35 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 36 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 37 | github.com/pierrec/lz4/v4 v4.0.3 h1:vNQKSVZNYUEAvRY9FaUXAF1XPbSOHJtDTiP41kzDz2E= 38 | github.com/pierrec/lz4/v4 v4.0.3/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= 39 | github.com/pkg/errors v0.8.1/go.mod 
h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 40 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 41 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 42 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 43 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 44 | github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= 45 | github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= 46 | github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= 47 | github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= 48 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 49 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 50 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= 51 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 52 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 53 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 54 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 55 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 56 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 57 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 58 | go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 59 | go.uber.org/atomic v1.8.0 h1:CUhrE4N1rqSE6FM9ecihEjRkLQu8cDfgDyoOs83mEY4= 60 | go.uber.org/atomic v1.8.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 61 | go.uber.org/multierr v1.6.0/go.mod 
h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= 62 | go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= 63 | go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= 64 | go.uber.org/zap v1.17.0 h1:MTjgFu6ZLKvY6Pvaqk97GlxNBuMpV4Hy/3P6tRGlI2U= 65 | go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= 66 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 67 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 68 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 69 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 70 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 71 | golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 72 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 73 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 74 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= 75 | gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= 76 | gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= 77 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 78 | gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= 79 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 80 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 81 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 82 | 
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 83 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 84 | -------------------------------------------------------------------------------- /k8s/cql-proxy-configmap.yml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | binaryData: 3 | scb.zip: 4 | kind: ConfigMap 5 | metadata: 6 | creationTimestamp: "2021-09-01T22:15:24Z" 7 | name: config 8 | -------------------------------------------------------------------------------- /k8s/cql-proxy.yml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: cql-proxy-deployment 5 | labels: 6 | app: cql-proxy 7 | spec: 8 | selector: 9 | matchLabels: 10 | app: cql-proxy 11 | template: 12 | metadata: 13 | labels: 14 | app: cql-proxy 15 | spec: 16 | containers: 17 | - name: cql-proxy 18 | image: datastax/cql-proxy:v0.1.3 19 | command: ["./cql-proxy"] 20 | args: ["--astra-bundle=/tmp/scb.zip","--username=Client ID","--password=Client Secret"] 21 | volumeMounts: 22 | - name: my-cm-vol 23 | mountPath: /tmp/ 24 | ports: 25 | - containerPort: 9042 26 | name: cql-port 27 | protocol: TCP 28 | volumes: 29 | - name: my-cm-vol 30 | configMap: 31 | name: cql-proxy-configmap 32 | -------------------------------------------------------------------------------- /parser/identifier.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import "strings" 18 | 19 | // Identifier is a CQL identifier 20 | type Identifier struct { 21 | id string 22 | ignoreCase bool 23 | } 24 | 25 | // IdentifierFromString creates an identifier from a string 26 | func IdentifierFromString(id string) Identifier { 27 | l := len(id) 28 | if l > 0 && id[0] == '"' { 29 | return Identifier{id: id[1 : l-1], ignoreCase: false} 30 | } else { 31 | return Identifier{id: id, ignoreCase: true} 32 | } 33 | } 34 | 35 | // correctly compares an identifier with a string 36 | func (i Identifier) equal(id string) bool { 37 | if i.ignoreCase { 38 | return strings.EqualFold(i.id, id) 39 | } else { 40 | return i.id == id 41 | } 42 | } 43 | 44 | func (i Identifier) isEmpty() bool { 45 | return len(i.id) == 0 46 | } 47 | 48 | func (i Identifier) String() string { 49 | if i.ignoreCase { 50 | return i.id 51 | } 52 | return "\"" + i.id + "\"" 53 | } 54 | 55 | func (i Identifier) ID() string { 56 | if i.ignoreCase { 57 | return strings.ToLower(i.id) 58 | } 59 | return strings.ReplaceAll(i.id, "\"\"", "\"") 60 | } 61 | -------------------------------------------------------------------------------- /parser/lexer.rl: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package parser

// token identifies the kind of lexeme returned by lexer.next.
type token int

// Token kinds. tkInvalid is produced for input that matches no other rule and
// tkEOF once the input is exhausted (see lexer.next). Keywords are matched
// case-insensitively by the /.../i patterns in the scanner below.
const (
	tkInvalid token = iota
	tkEOF
	tkSelect
	tkInsert
	tkUpdate
	tkDelete
	tkBegin
	tkApply
	tkBatch
	tkCreate
	tkAlter
	tkDrop
	tkInto
	tkFrom
	tkUse
	tkUsing
	tkIf
	tkWhere
	tkAnd
	tkToken
	tkIs
	tkIn
	tkNot
	tkIdentifier
	tkStar
	tkComma
	tkDot
	tkColon
	tkQMark
	tkEqual
	tkAdd
	tkSub
	tkAddEqual
	tkSubEqual
	tkNotEqual
	tkGt
	tkLt
	tkLtEqual
	tkGtEqual
	tkLparen
	tkRparen
	tkLsquare
	tkRsquare
	tkLcurly
	tkRcurly
	tkInteger
	tkFloat
	tkBool
	tkNull
	tkStringLiteral
	tkHexNumber
	tkUuid
	tkDuration
	tkNan
	tkInfinity
	tkEOS
)

// Aliases that let the parser refer to '<' and '>' as angle brackets rather
// than as comparison operators (same token values as tkLt/tkGt).
const (
	tkLangle=tkLt
	tkRangle=tkGt
)


// Ragel emits the static state tables for the "lex" machine here.
%%{
	machine lex;
	write data;
}%%

// lexer is a Ragel-generated scanner over a CQL string. Each call to next
// yields one token; mark/rewind save and restore a single position.
type lexer struct {
	data     string // input being lexed
	p, pe, m int    // current offset, end offset (len(data)) and marked offset
	id       string // raw text of the most recent tkIdentifier, quotes included
}

// initialize/reset lexer with data string to lex
// (the marked offset m and the last identifier id are not cleared)
func (l *lexer) init(data string) {
	l.p, l.pe = 0, len(data)
	l.data = data
}

// mark the current lexer position
func (l *lexer) mark() {
	l.m = l.p
}

// rewind position to the previously marked position
func (l *lexer) rewind() {
	l.p = l.m
}

// get the value of an identifier if that's the current token; otherwise, it's undefined
// (in practice it is the most recently lexed identifier, since l.id is never cleared)
func (l *lexer) identifier() Identifier {
	return IdentifierFromString(l.id)
}

// get the value of an identifier as a string if that's the current token; otherwise, it's undefined
func (l *lexer) identifierStr() string {
	return l.id
}

// next scans the next token starting at the current position, skipping spaces,
// tabs and newlines, and advances the position past it. For tkIdentifier the
// raw lexeme is stored in l.id. It returns tkEOF when no input remains.
func (l *lexer) next() token {
	data := l.data
	p, pe, eof := l.p, l.pe, l.pe
	// Ragel scanner state: act/ts/te track the current match (ts/te delimit the
	// lexeme, see the id action); cs is the machine state set by `write init`.
	act, ts, te, cs := 0, 0, 0, -1

	tk := tkInvalid

	if p == eof {
		return tkEOF
	}

	%%{
		ws = [ \t];
		nl = '\r\n' | '\n';
		id = ([a-zA-Z][a-zA-Z0-9_]*)|("\"" ([^\r\n\"] | "\"\"")* "\"");
		integer = '-'? digit+;
		exponent = [eE] ('+' | '-')? digit+;
		float = (integer exponent) | (integer '.' digit* exponent?);
		string = '\'' ([^\'] | '\'\'')* '\'';
		pgstring = '$' ([^\$] | '$$')* '$';
		hex = [a-f] | [A-F] | digit;
		hexnumber = '0' [xX] hex*;
		uuid = hex{8} '-' hex{4} '-' hex{4} '-' hex{4} '-' hex{12};
		durationunit = /y/i | /mo/i | /w/i | /d/i | /h/i | /m/i | /s/i | /ms/i | /µs/i | /us/i | /ns/i;
		duration = ('-'? digit+ durationunit (digit+ durationunit)*) |
		           ('-'? 'P' (digit+ 'Y')? (digit+ 'M')? (digit+ 'D')? ('T' (digit+ 'H')? (digit+ 'M')? (digit+ 'S')?)?) |
		           ('-'? 'P' digit+ 'W') |
		           '-'? 'P' digit digit digit digit '-' digit digit '-' digit digit 'T' digit digit ':' digit digit ':' digit digit;
		main := |*
			/select/i => { tk = tkSelect; fbreak; };
			/insert/i => { tk = tkInsert; fbreak; };
			/update/i => { tk = tkUpdate; fbreak; };
			/delete/i => { tk = tkDelete; fbreak; };
			/batch/i => { tk = tkBatch; fbreak; };
			/begin/i => { tk = tkBegin; fbreak; };
			/apply/i => { tk = tkApply; fbreak; };
			/create/i => { tk = tkCreate; fbreak; };
			/alter/i => { tk = tkAlter; fbreak; };
			/drop/i => { tk = tkDrop; fbreak; };
			/into/i => { tk = tkInto; fbreak; };
			/from/i => { tk = tkFrom; fbreak; };
			/use/i => { tk = tkUse; fbreak; };
			/using/i => { tk = tkUsing; fbreak; };
			/if/i => { tk = tkIf; fbreak; };
			/where/i => { tk = tkWhere; fbreak; };
			/and/i => { tk = tkAnd; fbreak; };
			/is/i => { tk = tkIs; fbreak; };
			/in/i => { tk = tkIn; fbreak; };
			/not/i => { tk = tkNot; fbreak; };
			/token/i => { tk = tkToken; fbreak; };
			/true/i | /false/i => { tk = tkBool; fbreak; };
			/null/i => { tk = tkNull; fbreak; };
			'\*' => { tk = tkStar; fbreak; };
			',' => { tk = tkComma; fbreak; };
			'\.' => { tk = tkDot; fbreak; };
			':' => { tk = tkColon; fbreak; };
			'?' => { tk = tkQMark; fbreak; };
			'(' => { tk = tkLparen; fbreak; };
			')' => { tk = tkRparen; fbreak; };
			'[' => { tk = tkLsquare; fbreak; };
			']' => { tk = tkRsquare; fbreak; };
			'{' => { tk = tkLcurly; fbreak; };
			'}' => { tk = tkRcurly; fbreak; };
			'=' => { tk = tkEqual; fbreak; };
			'<=' => { tk = tkLtEqual; fbreak; };
			'>=' => { tk = tkGtEqual; fbreak; };
			'<' => { tk = tkLt; fbreak; };
			'>' => { tk = tkGt; fbreak; };
			'!=' => { tk = tkNotEqual; fbreak; };
			'+' => { tk = tkAdd; fbreak; };
			'-' => { tk = tkSub; fbreak; };
			'+=' => { tk = tkAddEqual; fbreak; };
			'-=' => { tk = tkSubEqual; fbreak; };
			'-'? /nan/i => { tk = tkNan; fbreak; };
			'-'? /infinity/i => { tk = tkInfinity; fbreak; };
			';' => { tk = tkEOS; fbreak; };
			pgstring | string => { tk = tkStringLiteral; fbreak; };
			integer => { tk = tkInteger; fbreak; };
			float => { tk = tkFloat; fbreak; };
			hexnumber => { tk = tkHexNumber; fbreak; };
			duration => { tk = tkDuration; fbreak; };
			uuid => { tk = tkUuid; fbreak; };
			id => { tk = tkIdentifier; l.id = l.data[ts:te]; fbreak; };
			nl => { /* Skip */ };
			ws => { /* Skip */ };
			any => { tk = tkInvalid; fbreak; };
		*|;

		write init;
		write exec;
	}%%

	l.p = p

	// Running off the end of the input without producing a token means only
	// skipped whitespace/newlines remained.
	// NOTE(review): a single invalid character that is the very last byte of
	// the input also leaves tk == tkInvalid and p == eof (the catch-all `any`
	// rule consumes it and breaks), so it is reported as tkEOF rather than
	// tkInvalid — confirm that is intended.
	if tk == tkInvalid && p == eof {
		return tkEOF
	}

	return tk
}
--------------------------------------------------------------------------------
/parser/lexer_test.go:
--------------------------------------------------------------------------------
// Copyright (c) DataStax, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import ( 18 | "fmt" 19 | "testing" 20 | 21 | "github.com/stretchr/testify/assert" 22 | ) 23 | 24 | func TestLexerNext(t *testing.T) { 25 | var l lexer 26 | l.init("SELECT * FROM system.local") 27 | 28 | assert.Equal(t, tkSelect, l.next()) 29 | assert.Equal(t, tkStar, l.next()) 30 | assert.Equal(t, tkFrom, l.next()) 31 | assert.Equal(t, tkIdentifier, l.next()) 32 | assert.Equal(t, tkDot, l.next()) 33 | assert.Equal(t, tkIdentifier, l.next()) 34 | assert.Equal(t, tkEOF, l.next()) 35 | } 36 | 37 | func TestLexerLiterals(t *testing.T) { 38 | var tests = []struct { 39 | literal string 40 | tk token 41 | }{ 42 | {"0", tkInteger}, 43 | {"1", tkInteger}, 44 | {"-1", tkInteger}, 45 | {"1.", tkFloat}, 46 | {"0.0", tkFloat}, 47 | {"-0.0", tkFloat}, 48 | {"-1.e9", tkFloat}, 49 | {"-1.e+0", tkFloat}, 50 | {"tRue", tkBool}, 51 | {"False", tkBool}, 52 | {"'a'", tkStringLiteral}, 53 | {"'abc'", tkStringLiteral}, 54 | {"''''", tkStringLiteral}, 55 | {"$a$", tkStringLiteral}, 56 | {"$abc$", tkStringLiteral}, 57 | {"$$$$", tkStringLiteral}, 58 | {"0x", tkHexNumber}, 59 | {"0x0", tkHexNumber}, 60 | {"0xabcdef", tkHexNumber}, 61 | {"123e4567-e89b-12d3-a456-426614174000", tkUuid}, 62 | {"nan", tkNan}, 63 | {"-NaN", tkNan}, 64 | {"-infinity", tkInfinity}, 65 | {"-Infinity", tkInfinity}, 66 | {"1Y", tkDuration}, 67 | {"1µs", tkDuration}, 68 | } 69 | 70 | for _, tt := range tests { 71 | var l lexer 72 | l.init(tt.literal) 73 | assert.Equal(t, tt.tk, l.next(), fmt.Sprintf("failed on literal: %s", tt.literal)) 
74 | } 75 | } 76 | 77 | func TestLexerIdentifiers(t *testing.T) { 78 | var tests = []struct { 79 | literal string 80 | tk token 81 | expected string 82 | }{ 83 | {`system`, tkIdentifier, "system"}, 84 | {`sys"tem`, tkIdentifier, "sys"}, 85 | {`System`, tkIdentifier, "system"}, 86 | {`"system"`, tkIdentifier, "system"}, 87 | {`"system"`, tkIdentifier, "system"}, 88 | {`"System"`, tkIdentifier, "System"}, 89 | // below test verify correct escaping double quote character as per CQL definition: 90 | // identifier ::= unquoted_identifier | quoted_identifier 91 | // unquoted_identifier ::= re('[a-zA-Z][link:[a-zA-Z0-9]]*') 92 | // quoted_identifier ::= '"' (any character where " can appear if doubled)+ '"' 93 | {`""""`, tkIdentifier, "\""}, // outermost quotes indicate quoted string, inner two double quotes shall be treated as single quote 94 | {`""""""`, tkIdentifier, "\"\""}, // same as above, but 4 inner quotes result in 2 quotes 95 | {`"A"""""`, tkIdentifier, "A\"\""}, // outermost quotes indicate quoted string, 4 quotes after A result in 2 quotes 96 | {`"""A"""`, tkIdentifier, "\"A\""}, // outermost quotes indicate quoted string, 2 quotes before and after A result in single quotes 97 | {`"""""A"`, tkIdentifier, "\"\"A"}, // analogical to previous tests 98 | {`";`, tkInvalid, ""}, 99 | {`"""`, tkIdentifier, ""}, 100 | } 101 | 102 | for _, tt := range tests { 103 | var l lexer 104 | l.init(tt.literal) 105 | n := l.next() 106 | assert.Equal(t, tt.tk, n, fmt.Sprintf("failed on literal: %s", tt.literal)) 107 | if n == tkIdentifier { 108 | id := l.identifier() 109 | if id.ID() != tt.expected { 110 | t.Errorf("expected %s, got %s", tt.expected, l.id) 111 | } 112 | } 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /parser/metadata.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import ( 18 | "github.com/datastax/go-cassandra-native-protocol/datatype" 19 | "github.com/datastax/go-cassandra-native-protocol/message" 20 | ) 21 | 22 | var ( 23 | SystemLocalColumns = []*message.ColumnMetadata{ 24 | {Keyspace: "system", Table: "local", Name: "key", Type: datatype.Varchar}, 25 | {Keyspace: "system", Table: "local", Name: "rpc_address", Type: datatype.Inet}, 26 | {Keyspace: "system", Table: "local", Name: "data_center", Type: datatype.Varchar}, 27 | {Keyspace: "system", Table: "local", Name: "rack", Type: datatype.Varchar}, 28 | {Keyspace: "system", Table: "local", Name: "tokens", Type: datatype.NewSet(datatype.Varchar)}, 29 | {Keyspace: "system", Table: "local", Name: "release_version", Type: datatype.Varchar}, 30 | {Keyspace: "system", Table: "local", Name: "partitioner", Type: datatype.Varchar}, 31 | {Keyspace: "system", Table: "local", Name: "cluster_name", Type: datatype.Varchar}, 32 | {Keyspace: "system", Table: "local", Name: "cql_version", Type: datatype.Varchar}, 33 | {Keyspace: "system", Table: "local", Name: "schema_version", Type: datatype.Uuid}, 34 | {Keyspace: "system", Table: "local", Name: "native_protocol_version", Type: datatype.Varchar}, 35 | {Keyspace: "system", Table: "local", Name: "host_id", Type: datatype.Uuid}, 36 | } 37 | 38 | DseSystemLocalColumns = []*message.ColumnMetadata{ 39 | {Keyspace: 
"system", Table: "local", Name: "key", Type: datatype.Varchar}, 40 | {Keyspace: "system", Table: "local", Name: "rpc_address", Type: datatype.Inet}, 41 | {Keyspace: "system", Table: "local", Name: "data_center", Type: datatype.Varchar}, 42 | // The column "dse_version" is important for some DSE advance workloads esp. for graph to determine the graph 43 | // language. 44 | {Keyspace: "system", Table: "local", Name: "dse_version", Type: datatype.Varchar}, // DSE only 45 | {Keyspace: "system", Table: "local", Name: "rack", Type: datatype.Varchar}, 46 | {Keyspace: "system", Table: "local", Name: "tokens", Type: datatype.NewSet(datatype.Varchar)}, 47 | {Keyspace: "system", Table: "local", Name: "release_version", Type: datatype.Varchar}, 48 | {Keyspace: "system", Table: "local", Name: "partitioner", Type: datatype.Varchar}, 49 | {Keyspace: "system", Table: "local", Name: "cluster_name", Type: datatype.Varchar}, 50 | {Keyspace: "system", Table: "local", Name: "cql_version", Type: datatype.Varchar}, 51 | {Keyspace: "system", Table: "local", Name: "schema_version", Type: datatype.Uuid}, 52 | {Keyspace: "system", Table: "local", Name: "native_protocol_version", Type: datatype.Varchar}, 53 | {Keyspace: "system", Table: "local", Name: "host_id", Type: datatype.Uuid}, 54 | } 55 | 56 | SystemPeersColumns = []*message.ColumnMetadata{ 57 | {Keyspace: "system", Table: "peers", Name: "peer", Type: datatype.Inet}, 58 | {Keyspace: "system", Table: "peers", Name: "rpc_address", Type: datatype.Inet}, 59 | {Keyspace: "system", Table: "peers", Name: "data_center", Type: datatype.Varchar}, 60 | {Keyspace: "system", Table: "peers", Name: "rack", Type: datatype.Varchar}, 61 | {Keyspace: "system", Table: "peers", Name: "tokens", Type: datatype.NewSet(datatype.Varchar)}, 62 | {Keyspace: "system", Table: "peers", Name: "release_version", Type: datatype.Varchar}, 63 | {Keyspace: "system", Table: "peers", Name: "schema_version", Type: datatype.Uuid}, 64 | {Keyspace: "system", Table: "peers", 
Name: "host_id", Type: datatype.Uuid}, 65 | } 66 | 67 | DseSystemPeersColumns = []*message.ColumnMetadata{ 68 | {Keyspace: "system", Table: "peers", Name: "peer", Type: datatype.Inet}, 69 | {Keyspace: "system", Table: "peers", Name: "rpc_address", Type: datatype.Inet}, 70 | {Keyspace: "system", Table: "peers", Name: "data_center", Type: datatype.Varchar}, 71 | {Keyspace: "system", Table: "peers", Name: "dse_version", Type: datatype.Varchar}, // DSE only 72 | {Keyspace: "system", Table: "peers", Name: "rack", Type: datatype.Varchar}, 73 | {Keyspace: "system", Table: "peers", Name: "tokens", Type: datatype.NewSet(datatype.Varchar)}, 74 | {Keyspace: "system", Table: "peers", Name: "release_version", Type: datatype.Varchar}, 75 | {Keyspace: "system", Table: "peers", Name: "schema_version", Type: datatype.Uuid}, 76 | {Keyspace: "system", Table: "peers", Name: "host_id", Type: datatype.Uuid}, 77 | } 78 | 79 | SystemSchemaKeyspaces = []*message.ColumnMetadata{ 80 | {Keyspace: "system", Table: "schema_keyspaces", Name: "keyspace_name", Type: datatype.Varchar}, 81 | {Keyspace: "system", Table: "schema_keyspaces", Name: "durable_writes", Type: datatype.Boolean}, 82 | {Keyspace: "system", Table: "schema_keyspaces", Name: "strategy_class", Type: datatype.Varchar}, 83 | {Keyspace: "system", Table: "schema_keyspaces", Name: "strategy_options", Type: datatype.Varchar}, 84 | } 85 | 86 | SystemSchemaColumnFamilies = []*message.ColumnMetadata{ 87 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "keyspace_name", Type: datatype.Varchar}, 88 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "columnfamily_name", Type: datatype.Varchar}, 89 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "bloom_filter_fp_chance", Type: datatype.Double}, 90 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "caching", Type: datatype.Varchar}, 91 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "cf_id", Type: datatype.Uuid}, 92 | {Keyspace: 
"system", Table: "schema_columnfamilies", Name: "comment", Type: datatype.Varchar}, 93 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "compaction_strategy_class", Type: datatype.Varchar}, 94 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "compaction_strategy_options", Type: datatype.Varchar}, 95 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "comparator", Type: datatype.Varchar}, 96 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "compression_parameters", Type: datatype.Varchar}, 97 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "default_time_to_live", Type: datatype.Int}, 98 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "default_validator", Type: datatype.Varchar}, 99 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "dropped_columns", Type: datatype.NewMap(datatype.Varchar, datatype.Bigint)}, 100 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "gc_grace_seconds", Type: datatype.Int}, 101 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "is_dense", Type: datatype.Boolean}, 102 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "key_validator", Type: datatype.Varchar}, 103 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "local_read_repair_chance", Type: datatype.Double}, 104 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "max_compaction_threshold", Type: datatype.Int}, 105 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "max_index_interval", Type: datatype.Int}, 106 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "memtable_flush_period_in_ms", Type: datatype.Int}, 107 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "min_compaction_threshold", Type: datatype.Int}, 108 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "min_index_interval", Type: datatype.Int}, 109 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "read_repair_chance", Type: 
datatype.Double}, 110 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "speculative_retry", Type: datatype.Varchar}, 111 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "subcomparator", Type: datatype.Varchar}, 112 | {Keyspace: "system", Table: "schema_columnfamilies", Name: "type", Type: datatype.Varchar}, 113 | } 114 | 115 | SystemSchemaColumns = []*message.ColumnMetadata{ 116 | {Keyspace: "system", Table: "schema_columns", Name: "keyspace_name", Type: datatype.Varchar}, 117 | {Keyspace: "system", Table: "schema_columns", Name: "columnfamily_name", Type: datatype.Varchar}, 118 | {Keyspace: "system", Table: "schema_columns", Name: "column_name", Type: datatype.Varchar}, 119 | {Keyspace: "system", Table: "schema_columns", Name: "component_index", Type: datatype.Int}, 120 | {Keyspace: "system", Table: "schema_columns", Name: "index_name", Type: datatype.Varchar}, 121 | {Keyspace: "system", Table: "schema_columns", Name: "index_options", Type: datatype.Varchar}, 122 | {Keyspace: "system", Table: "schema_columns", Name: "index_type", Type: datatype.Varchar}, 123 | {Keyspace: "system", Table: "schema_columns", Name: "type", Type: datatype.Varchar}, 124 | {Keyspace: "system", Table: "schema_columns", Name: "validator", Type: datatype.Varchar}, 125 | } 126 | 127 | SystemSchemaUsertypes = []*message.ColumnMetadata{ 128 | {Keyspace: "system", Table: "schema_usertypes", Name: "keyspace_name", Type: datatype.Varchar}, 129 | {Keyspace: "system", Table: "schema_usertypes", Name: "type_name", Type: datatype.Varchar}, 130 | {Keyspace: "system", Table: "schema_usertypes", Name: "field_names", Type: datatype.NewList(datatype.Varchar)}, 131 | {Keyspace: "system", Table: "schema_usertypes", Name: "field_types", Type: datatype.NewList(datatype.Varchar)}, 132 | } 133 | ) 134 | 135 | var SystemColumnsByName = map[string][]*message.ColumnMetadata{ 136 | "local": SystemLocalColumns, 137 | "peers": SystemPeersColumns, 138 | "schema_keyspaces": 
SystemSchemaKeyspaces, 139 | "schema_columnfamilies": SystemSchemaColumnFamilies, 140 | "schema_columns": SystemSchemaColumns, 141 | "schema_usertypes": SystemSchemaUsertypes, 142 | } 143 | 144 | func FindColumnMetadata(columns []*message.ColumnMetadata, name string) *message.ColumnMetadata { 145 | for _, column := range columns { 146 | if column.Name == name { 147 | return column 148 | } 149 | } 150 | return nil 151 | } 152 | -------------------------------------------------------------------------------- /parser/parse_batch.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import "errors" 18 | 19 | // Determines if a batch statement is idempotent. 20 | // 21 | // A batch statement is not idempotent if: 22 | // * it updates counters 23 | // * contains DML statements that are not idempotent 24 | // 25 | // batchStatement: 'BEGIN' ( 'UNLOGGED' | 'COUNTER' )? 'BATCH' 26 | // usingClause? 27 | // ( batchChildStatement ';'? 
)* 28 | // 'APPLY' 'BATCH' 29 | // 30 | // batchChildStatement: insertStatement | updateStatement | deleteStatement 31 | // 32 | func isIdempotentBatchStmt(l *lexer) (idempotent bool, err error) { 33 | t := l.next() 34 | 35 | if isUnreservedKeyword(l, t, "unlogged") { 36 | t = l.next() 37 | } else if isUnreservedKeyword(l, t, "counter") { 38 | return false, nil // Updates to counters are not idempotent 39 | } 40 | 41 | if tkBatch != t { 42 | return false, errors.New("expected 'BATCH' at the beginning of a batch statement") 43 | } 44 | 45 | t, err = parseUsingClause(l, l.next()) 46 | if err != nil { 47 | return false, err 48 | } 49 | 50 | for tkApply != t && tkEOF != t { 51 | switch t { 52 | case tkInsert: 53 | idempotent, t, err = isIdempotentInsertStmt(l) 54 | case tkUpdate: 55 | idempotent, t, err = isIdempotentUpdateStmt(l) 56 | case tkDelete: 57 | idempotent, t, err = isIdempotentDeleteStmt(l) 58 | default: 59 | return false, errors.New("unexpected child statement in batch statement") 60 | } 61 | if t == tkEOS { // Skip ';' 62 | t = l.next() 63 | } 64 | if !idempotent { 65 | return idempotent, err 66 | } 67 | } 68 | 69 | if tkApply != t { 70 | return false, errors.New("expected 'APPLY' after child statements at the end of a batch statement") 71 | } 72 | 73 | if tkBatch != l.next() { 74 | return false, errors.New("expected 'BATCH' at the end of a batch statement") 75 | } 76 | 77 | return true, nil 78 | } 79 | -------------------------------------------------------------------------------- /parser/parse_batch_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestIsIdempotentBatchStmt(t *testing.T) { 24 | var tests = []struct { 25 | query string 26 | idempotent bool 27 | hasError bool 28 | msg string 29 | }{ 30 | // Idempotent 31 | {`BEGIN BATCH 32 | INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1) 33 | APPLY BATCH`, 34 | true, false, "simple"}, 35 | {`BEGIN BATCH 36 | INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1) 37 | APPLY BATCH;`, 38 | true, false, "semicolon at the end of the batch"}, 39 | {`BEGIN BATCH 40 | INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1); 41 | INSERT INTO table (a, b, c) VALUES (2, 'b', 0.2); 42 | APPLY BATCH;`, 43 | true, false, "semicolon at the end of child statements"}, 44 | {`BEGIN BATCH 45 | APPLY BATCH`, 46 | true, false, "empty"}, 47 | {`BEGIN BATCH 48 | UPDATE ks.table SET b = 0 WHERE a > 100 49 | DELETE a FROM ks.table 50 | INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1) 51 | INSERT INTO table (a, b, c) VALUES (2, 'b', 0.2) 52 | APPLY BATCH`, 53 | true, false, "multiple statements"}, 54 | 55 | // Invalid 56 | {`BATCH 57 | SELECT * FROM table 58 | APPLY BATCH`, 59 | false, true, "no starting 'BEGIN'"}, 60 | {`BEGIN BATCH 61 | SELECT * FROM table 62 | APPLY BATCH`, 63 | false, true, "contains 'SELECT'"}, 64 | {`BEGIN 65 | INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1) 66 | APPLY BATCH`, 67 | false, true, "no starting 'BATCH'"}, 68 | {`BEGIN BATCH 69 | INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1) 70 | 
BATCH`, 71 | false, true, "no ending 'APPLY'"}, 72 | {`BEGIN BATCH 73 | INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1) 74 | APPLY`, 75 | false, true, "no ending 'BATCH'"}, 76 | 77 | // Not idempotent 78 | {`BEGIN COUNTER BATCH 79 | INSERT INTO table (a, b) VALUES ('a', 0) 80 | APPLY BATCH`, 81 | false, false, "batch counter insert"}, 82 | {`BEGIN COUNTER BATCH 83 | UPDATE table SET a = a + 1 84 | APPLY BATCH`, 85 | false, false, "batch counter update"}, 86 | {`BEGIN BATCH 87 | INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1) 88 | DELETE a, b, c[1] FROM ks.table; 89 | APPLY BATCH`, 90 | false, false, "delete from list in batch"}, 91 | {`BEGIN BATCH 92 | INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1) 93 | INSERT INTO table (a, b, c) VALUES (now(), 'a', 0.1) 94 | APPLY BATCH;`, 95 | false, false, "contains now()"}, 96 | // Found defect 97 | {"BEGIN BATCH USING TIMESTAMP 1481124356754405\nINSERT INTO cycling.cyclist_expenses \n (cyclist_name, expense_id, amount, description, paid) \n VALUES ('Vera ADRIAN', 2, 13.44, 'Lunch', true);\nINSERT INTO cycling.cyclist_expenses \n (cyclist_name, expense_id, amount, description, paid) \n VALUES ('Vera ADRIAN', 3, 25.00, 'Dinner', true);\nAPPLY BATCH;", 98 | true, false, "has semicolons after each statement"}, 99 | } 100 | 101 | for _, tt := range tests { 102 | idempotent, err := IsQueryIdempotent(tt.query) 103 | assert.True(t, (err != nil) == tt.hasError, tt.msg) 104 | assert.Equal(t, tt.idempotent, idempotent, "invalid idempotency", tt.msg) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /parser/parse_delete.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import "errors" 18 | 19 | // Determines if a delete statement is idempotent. 20 | // 21 | // A delete statement is not idempotent if it: 22 | // * removes an element from a list 23 | // * uses a lightweight transaction (LWT) e.g. 'IF EXISTS' or 'IF a > 0' 24 | // * has a relation that uses a non-idempotent function e.g. now() or uuid() 25 | // 26 | // deleteStatement: 'DELETE' deleteOperations? 'FROM' tableName ( 'USING' timestamp )? whereClause ( 'IF' ( 'EXISTS' | conditions ) )? 27 | // deleteOperations: deleteOperation ( ',' deleteOperation )* 28 | // deleteOperation: identifier | identifier '[' term ']' | identifier '.' identifier 29 | // tableName: ( identifier '.' )?
identifier 30 | // 31 | func isIdempotentDeleteStmt(l *lexer) (idempotent bool, t token, err error) { 32 | t = l.next() 33 | for ; tkFrom != t && tkEOF != t; t = skipToken(l, l.next(), tkComma) { 34 | if tkIdentifier != t { 35 | return false, tkInvalid, errors.New("unexpected token after 'DELETE' in delete statement") 36 | } 37 | 38 | l.mark() 39 | switch t = l.next(); t { 40 | case tkLsquare: 41 | var typ termType 42 | if idempotent, typ, err = parseTerm(l, l.next()); !idempotent { 43 | return idempotent, tkInvalid, err 44 | } 45 | if tkRsquare != l.next() { 46 | return false, tkInvalid, errors.New("expected closing ']' for the delete operation") 47 | } 48 | if !isIdempotentDeleteElementTermType(typ) { 49 | return false, tkInvalid, nil 50 | } 51 | case tkDot: 52 | if tkIdentifier != l.next() { 53 | return false, tkInvalid, errors.New("expected another identifier after '.' for delete operation") 54 | } 55 | default: 56 | l.rewind() 57 | } 58 | } 59 | 60 | if tkFrom != t { 61 | return false, tkInvalid, errors.New("expected 'FROM' after delete operation(s) in delete statement") 62 | } 63 | 64 | if tkIdentifier != l.next() { 65 | return false, tkInvalid, errors.New("expected identifier after 'FROM' in delete statement") 66 | } 67 | 68 | _, _, t, err = parseQualifiedIdentifier(l) 69 | if err != nil { 70 | return false, tkInvalid, err 71 | } 72 | 73 | t, err = parseUsingClause(l, t) 74 | if err != nil { 75 | return false, tkInvalid, err 76 | } 77 | 78 | if tkWhere == t { 79 | idempotent, t, err = parseWhereClause(l) 80 | if !idempotent { 81 | return idempotent, tkInvalid, err 82 | } 83 | } 84 | 85 | for ; !isDMLTerminator(t); t = l.next() { 86 | if tkIf == t { 87 | return false, tkInvalid, nil 88 | } 89 | } 90 | return true, t, nil 91 | } 92 | 93 | // Delete element terms can be one of the following: 94 | // * Literal (idempotent, if not an integer literal) 95 | // * Bind marker (ambiguous, so not idempotent) 96 | // * Function call (ambiguous, so not idempotent) 97 | 
// * Type cast (ambiguous) 98 | func isIdempotentDeleteElementTermType(typ termType) bool { 99 | return typ != termIntegerLiteral && typ != termBindMarker && typ != termFunctionCall && typ != termCast 100 | } 101 | -------------------------------------------------------------------------------- /parser/parse_delete_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestIsIdempotentDeleteStmt(t *testing.T) { 24 | var tests = []struct { 25 | query string 26 | idempotent bool 27 | hasError bool 28 | msg string 29 | }{ 30 | {"DELETE FROM table", true, false, "simple"}, 31 | {"DELETE FROM table;", true, false, "semicolon at the end"}, 32 | {"DELETE a FROM table", true, false, "w/ operation"}, 33 | {"DELETE a FROM ks.table", true, false, "simple qualified table"}, 34 | {"DELETE a.b FROM table", true, false, "UDT field"}, 35 | {"DELETE a, b, c['key'] FROM table", true, false, "multiple operations"}, 36 | {"DELETE a['key'] FROM ks.table", true, false, "map field"}, 37 | {"DELETE a FROM table WHERE a > 0", true, false, "where clause"}, 38 | 39 | // Invalid 40 | {"DELETE a. 
FROM table", false, true, "no UDT field"}, 41 | {"DELETE FROM ks.", false, true, "no table after '.'"}, 42 | {"DELETE FROM table WHERE", true, false, "where clause w/ no relation"}, 43 | {"DELETE a[0 table WHERE b > 0", false, true, "collection element with no closing square bracket"}, 44 | 45 | // Not idempotent 46 | {"DELETE a, b, c[1] FROM ks.table", false, false, "multiple with list element"}, 47 | {"DELETE FROM ks.table WHERE a > toTimestamp(now())", false, false, "now() relation"}, 48 | {"DELETE FROM table WHERE a > 0 IF EXISTS", false, false, "LWT"}, 49 | {"DELETE a['key'] FROM table WHERE a > 0 IF EXISTS", false, false, "LWT w/ map field"}, 50 | {"DELETE a['key'] FROM table WHERE a > 0 IF EXISTS;", false, false, "LWT w/ map field and semicolon"}, 51 | 52 | // Ambiguous 53 | {"DELETE a[0] FROM ks.table", false, false, "potentially a list element"}, 54 | {"DELETE a[?] FROM ks.table", false, false, "potentially a list element w/ bind marker"}, 55 | {"DELETE a[func()] FROM ks.table", false, false, "potentially a list element w/ function call"}, 56 | } 57 | 58 | for _, tt := range tests { 59 | idempotent, err := IsQueryIdempotent(tt.query) 60 | assert.True(t, (err != nil) == tt.hasError, tt.msg) 61 | assert.Equal(t, tt.idempotent, idempotent, "invalid idempotency", tt.msg) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /parser/parse_insert.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import "errors" 18 | 19 | // Determines if an insert statement is idempotent. 20 | // 21 | // An insert statement is not idempotent if it contains a non-idempotent term e.g. 'now()' or 'uuid()' or if it uses 22 | // lightweight transactions (LWTs) e.g. using 'IF NOT EXISTS'. 23 | // 24 | // insertStatement: 'INSERT' 'INTO' tableName ( namedValues | jsonClause ) ( 'IF' 'NOT' 'EXISTS' )? 25 | // tableName: ( identifier '.' )? identifier 26 | // namedValues: '(' identifiers ')' 'VALUES' '(' terms ')' 27 | // jsonClause: 'JSON' stringLiteral ( 'DEFAULT' ( 'NULL' | 'UNSET' ) )? 
28 | // 29 | func isIdempotentInsertStmt(l *lexer) (idempotent bool, t token, err error) { 30 | t = l.next() 31 | if tkInto != t { 32 | return false, tkInvalid, errors.New("expected 'INTO' after 'INSERT' for insert statement") 33 | } 34 | 35 | if t = l.next(); tkIdentifier != t { 36 | return false, tkInvalid, errors.New("expected identifier after 'INTO' in insert statement") 37 | } 38 | 39 | _, _, t, err = parseQualifiedIdentifier(l) 40 | if err != nil { 41 | return false, tkInvalid, err 42 | } 43 | 44 | if !isUnreservedKeyword(l, t, "json") { 45 | if tkLparen != t { 46 | return false, tkInvalid, errors.New("expected '(' after table name for insert statement") 47 | } 48 | 49 | err = parseIdentifiers(l, l.next()) 50 | if err != nil { 51 | return false, tkInvalid, err 52 | } 53 | 54 | if !isUnreservedKeyword(l, l.next(), "values") { 55 | return false, tkInvalid, errors.New("expected 'VALUES' after identifiers in insert statement") 56 | } 57 | 58 | if t != l.next() { 59 | return false, tkInvalid, errors.New("expected '(' after 'VALUES' in insert statement") 60 | } 61 | 62 | for t = l.next(); tkRparen != t && tkEOF != t; t = skipToken(l, l.next(), tkComma) { 63 | if idempotent, _, err = parseTerm(l, t); !idempotent { 64 | return idempotent, tkInvalid, err 65 | } 66 | } 67 | 68 | if t != tkRparen { 69 | return false, tkInvalid, errors.New("expected closing ')' for 'VALUES' list in insert statement") 70 | } 71 | } 72 | 73 | for t = l.next(); !isDMLTerminator(t); t = l.next() { 74 | if tkIf == t { 75 | return false, tkInvalid, nil 76 | } 77 | } 78 | return true, t, nil 79 | } 80 | -------------------------------------------------------------------------------- /parser/parse_insert_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestIsIdempotentInsertStmt(t *testing.T) { 24 | var tests = []struct { 25 | query string 26 | idempotent bool 27 | hasError bool 28 | msg string 29 | }{ 30 | {"INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1)", true, false, "simple"}, 31 | {"INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1);", true, false, "semicolon"}, 32 | {"INSERT INTO ks.table (a, b, c) VALUES (1, 'a', 0.1)", true, false, "simple qualified table name"}, 33 | {"INSERT INTO table () VALUES ()", true, false, "no identifier of values"}, 34 | {"INSERT INTO table JSON '{}'", true, false, "JSON"}, 35 | 36 | // Invalid 37 | {"INSERT table (a, b, c) VALUES (1, 'a', 0.1)", false, true, "missing 'INTO'"}, 38 | {"INSERT INTO (a, b, c) VALUES (1, 'a', 0.1)", false, true, "missing table name"}, 39 | {"INSERT INTO table a, b, c) VALUES (1, 'a', 0.1)", false, true, "missing opening paren. on identifiers"}, 40 | {"INSERT INTO table (a, b, c VALUES (1, 'a', 0.1)", false, true, "missing closing paren on identifiers"}, 41 | {"INSERT INTO table (a, b, c) (1, 'a', 0.1)", false, true, "missing 'VALUES'"}, 42 | {"INSERT INTO table (0, b, c) VALUES (1, 'a', 0.1)", false, true, "unexpected term in identifiers"}, 43 | {"INSERT INTO table (a, b, c) VALUES (invalid, 'a', 0.1)", false, true, "invalid value"}, 44 | {"INSERT INTO table (a, b, c) VALUES 1, 'a', 0.1)", false, true, "missing opening paren. 
on values"}, 45 | {"INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1", false, true, "missing closing paren. on values"}, 46 | 47 | // Not idempotent 48 | {"INSERT INTO table (a, b, c) VALUES (now(), 'a', 0.1)", false, false, "simple w/ 'now()'"}, 49 | {"INSERT INTO table (a, b, c) VALUES (0, uuid(), 0.1)", false, false, "simple w/ 'uuid()'"}, 50 | {"INSERT INTO table (a, b, c) VALUES (1, 'a', 0.1) IF NOT EXISTS", false, false, "simple w/ LWT"}, 51 | {"INSERT INTO table () VALUES () IF NOT EXIST", false, false, "no identifier of values w/ LWT"}, 52 | {"INSERT INTO table JSON '{}' IF NOT EXIST", false, false, "'JSON' w/ LWT"}, 53 | {"INSERT INTO table JSON '{}' IF NOT EXIST;", false, false, "'JSON' w/ LWT and semicolon"}, 54 | } 55 | 56 | for _, tt := range tests { 57 | idempotent, err := IsQueryIdempotent(tt.query) 58 | assert.True(t, (err != nil) == tt.hasError, tt.msg) 59 | assert.Equal(t, tt.idempotent, idempotent, "invalid idempotency", tt.msg) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /parser/parse_relation.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import "errors" 18 | 19 | // Determine if where clause is idempotent for an UPDATE or DELETE mutation. 
20 | // 21 | // whereClause: 'WHERE' relation ( 'AND' relation )* 22 | // 23 | func parseWhereClause(l *lexer) (idempotent bool, t token, err error) { 24 | for t = l.next(); tkIf != t && !isDMLTerminator(t); t = skipToken(l, l.next(), tkAnd) { 25 | idempotent, err = parseRelation(l, t) 26 | if !idempotent { 27 | return idempotent, tkInvalid, err 28 | } 29 | } 30 | return true, t, nil 31 | } 32 | 33 | // Determine if a relation is idempotent for an UPDATE or DELETE mutation. 34 | // 35 | // relation 36 | // : identifier operator term 37 | // | 'TOKEN' '(' identifiers ')' operator term 38 | // | identifier 'LIKE' term 39 | // | identifier 'IS' 'NOT' 'NULL' 40 | // | identifier 'CONTAINS' 'KEY'? term 41 | // | identifier '[' term ']' operator term 42 | // | identifier 'IN' ( '(' terms? ')' | bindMarker ) 43 | // | '(' identifiers ')' 'IN' ( '(' terms? ')' | bindMarker ) 44 | // | '(' identifiers ')' operator ( '(' terms? ')' | bindMarker ) 45 | // | '(' relation ')' 46 | // 47 | func parseRelation(l *lexer, t token) (idempotent bool, err error) { 48 | switch t { 49 | case tkIdentifier: 50 | switch t = l.next(); t { 51 | case tkIdentifier: 52 | if isUnreservedKeyword(l, t, "contains") { // identifier 'contains' 'key'? 
term 53 | if t = l.next(); isUnreservedKeyword(l, t, "key") { 54 | t = l.next() 55 | } 56 | if idempotent, _, err = parseTerm(l, t); !idempotent { 57 | return idempotent, err 58 | } 59 | 60 | } else if isUnreservedKeyword(l, t, "like") { // identifier 'like' term 61 | if idempotent, _, err = parseTerm(l, l.next()); !idempotent { 62 | return idempotent, err 63 | } 64 | } else { 65 | return false, errors.New("unexpected token parsing relation") 66 | } 67 | case tkEqual, tkRangle, tkLtEqual, tkLangle, tkGtEqual, tkNotEqual: // identifier operator term 68 | if idempotent, _, err = parseTerm(l, l.next()); !idempotent { 69 | return idempotent, err 70 | } 71 | case tkIs: // identifier 'is' 'not' 'null' 72 | if t = l.next(); tkNot != t { 73 | return false, errors.New("expected 'not' in relation after 'is'") 74 | } 75 | if t = l.next(); tkNull != t { 76 | return false, errors.New("expected 'null' in relation after 'is not'") 77 | } 78 | case tkLsquare: // identifier '[' term ']' operator term 79 | if idempotent, _, err = parseTerm(l, l.next()); !idempotent { 80 | return idempotent, err 81 | } 82 | if t = l.next(); tkRsquare != t { 83 | return false, errors.New("expected closing ']' after term in relation") 84 | } 85 | if t = l.next(); !isOperator(t) { 86 | return false, errors.New("expected operator after term in relation") 87 | } 88 | if idempotent, _, err = parseTerm(l, l.next()); !idempotent { 89 | return idempotent, err 90 | } 91 | case tkIn: // identifier 'in' ('(' terms? 
')' | bindMarker) 92 | switch t = l.next(); t { 93 | case tkLparen: 94 | for t = l.next(); tkRparen != t && tkEOF != t; t = skipToken(l, l.next(), tkComma) { 95 | if idempotent, _, err = parseTerm(l, t); !idempotent { 96 | return idempotent, err 97 | } 98 | } 99 | if tkRparen != t { 100 | return false, errors.New("expected closing ')' after terms") 101 | } 102 | case tkColon, tkQMark: 103 | err = parseBindMarker(l, t) 104 | if err != nil { 105 | return false, err 106 | } 107 | default: 108 | return false, errors.New("unexpected token for 'IN' relation") 109 | } 110 | default: 111 | return false, errors.New("unexpected token parsing relation") 112 | } 113 | case tkToken: // token '(' identifiers ')' operator term 114 | if t = l.next(); tkLparen != t { 115 | return false, errors.New("expected '(' after 'token'") 116 | } 117 | err = parseIdentifiers(l, l.next()) 118 | if err != nil { 119 | return false, err 120 | } 121 | if t = l.next(); !isOperator(t) { 122 | return false, errors.New("expected operator after identifier list in relation") 123 | } 124 | if idempotent, _, err = parseTerm(l, l.next()); !idempotent { 125 | return idempotent, err 126 | } 127 | case tkLparen: // '(' relation ')' | '(' identifiers ')' ... 
128 | l.mark() 129 | maybeId, maybeCommaOrRparen := l.next(), l.next() // Peek a couple tokens to see if this is an identifier list 130 | if tkIdentifier == maybeId && (maybeCommaOrRparen == tkComma || maybeCommaOrRparen == tkRparen) { 131 | t = skipToken(l, maybeCommaOrRparen, tkComma) 132 | err = parseIdentifiers(l, t) 133 | if err != nil { 134 | return false, err 135 | } 136 | return parseIdentifiersRelation(l) 137 | } else { 138 | l.rewind() 139 | idempotent, err = parseRelation(l, l.next()) 140 | if !idempotent { 141 | return idempotent, err 142 | } 143 | if tkRparen != l.next() { 144 | return false, errors.New("expected closing ')' after parenthesized relation") 145 | } 146 | } 147 | default: 148 | return false, errors.New("unexpected token in relation") 149 | } 150 | return true, nil 151 | } 152 | 153 | // Determines if identifiers relation is idempotent. 154 | // 155 | // ... 'IN' ( '(' terms? ')' | bindMarker ) 156 | // ... operator ( '(' terms? ')' | bindMarker ) 157 | // 158 | func parseIdentifiersRelation(l *lexer) (idempotent bool, err error) { 159 | switch t := l.next(); t { 160 | case tkIn, tkEqual, tkLt, tkLtEqual, tkGt, tkGtEqual, tkNotEqual: 161 | switch t = l.next(); t { 162 | case tkColon, tkQMark: 163 | err = parseBindMarker(l, t) 164 | if err != nil { 165 | return false, err 166 | } 167 | case tkLparen: 168 | for t = l.next(); tkRparen != t && tkEOF != t; t = skipToken(l, l.next(), tkComma) { 169 | if idempotent, _, err = parseTerm(l, t); !idempotent { 170 | return idempotent, err 171 | } 172 | } 173 | if tkRparen != t { 174 | return false, errors.New("expected closing ')' in identifiers relation") 175 | } 176 | default: 177 | return false, errors.New("unexpected term token in identifiers relation") 178 | } 179 | default: 180 | return false, errors.New("unexpected token in identifiers relation") 181 | } 182 | 183 | return true, nil 184 | } 185 | 186 | // Parses the remainder of a bind marker. 
187 | // 188 | // bindMarker 189 | // : ':' identifier 190 | // | '?' 191 | func parseBindMarker(l *lexer, t token) error { 192 | switch t { 193 | case tkColon: 194 | if tkIdentifier != l.next() { 195 | return errors.New("expected identifier after ':' for named bind marker") 196 | } 197 | case tkQMark: 198 | // Do nothing 199 | default: 200 | return errors.New("invalid bind marker") 201 | } 202 | return nil 203 | } 204 | 205 | func isOperator(t token) bool { 206 | return tkEqual == t || tkLt == t || tkLtEqual == t || tkGt == t || tkGtEqual == t || tkNotEqual == t 207 | } 208 | -------------------------------------------------------------------------------- /parser/parse_relation_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package parser 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestParseRelation(t *testing.T) { 24 | var tests = []struct { 25 | relation string 26 | idempotent bool 27 | hasError bool 28 | msg string 29 | }{ 30 | {"id > 0", true, false, "simple operator relation"}, 31 | {"token (a, b, c) > (0, 1, 2)", true, false, "'token' relation"}, 32 | {"id LIKE 'abc'", true, false, "'like' relation"}, 33 | {"id IS NOT NULL", true, false, "'is not null' relation"}, 34 | {"id CONTAINS 'abc'", true, false, "'contains' relation"}, 35 | {"id CONTAINS KEY 'abc'", true, false, "'contains key' relation"}, 36 | {"id[0] > 0", true, false, "index collection w/ int relation"}, 37 | {"id['abc'] > 'def'", true, false, "index collection w/string relation"}, 38 | {"id IN ?", true, false, "'IN' w/ position bind marker relation "}, 39 | {"id IN :column", true, false, "'IN' w/ named bind marker relation"}, 40 | {"id IN (0, 1, 2)", true, false, "'IN' w/ terms"}, 41 | {"((((id > 0))))", true, false, "arbitrary number of parens"}, 42 | {"(id1, id2, id3) IN ()", true, false, "list in empty"}, 43 | {"(id1, id2, id3) IN ?", true, false, "list in positional bind marker"}, 44 | {"(id1, id2, id3) IN :named", true, false, "list in named bind marker"}, 45 | {"(id1, id2, id3) IN (?, ?, :named)", true, false, "list in list of bind markers"}, 46 | {"(id1, id2, id3) IN (('a', ?, 0), ('b', :named, 1))", true, false, "list in list of tuples"}, 47 | {"(id1, id2, id3) > ?", true, false, "list in positional bind marker"}, 48 | {"(id1, id2, id3) < :named", true, false, "list in named bind marker"}, 49 | {"(id1, id2, id3) >= (?, ?, :named)", true, false, "list in list of bind markers"}, 50 | {"(id1, id2, id3) <= (('a', ?, 0), ('b', :named, 1))", true, false, "list in list of tuples"}, 51 | 52 | // Invalid 53 | {"id >", false, true, "missing term"}, 54 | {"id == 0", false, true, "invalid operator"}, 55 | {"token a, b, c) > (0, 1, 2)", false, true, 
"invalid 'token' relation w/ missing identifiers opening paren"}, 56 | {"token (a, b, c > (0, 1, 2)", false, true, "invalid 'token' relation w/ missing identifiers closing paren"}, 57 | {"token (a, b, c) > (0, 1, 2", false, true, "invalid 'token' relation w/ missing tuple closing paren"}, 58 | {"id IS", false, true, "invalid 'is not null' relation"}, 59 | {"id CONTAINS", false, true, "invalid 'contains' relation w/ missing term"}, 60 | {"id CONTAINS KEY", false, true, " invalid 'contains key' relation w/ missing term"}, 61 | {"id[0 > 0", false, true, "invalid index collection w/ int relation w/ missing closing square bracket"}, 62 | {"id[0] >", false, true, "invalid index collection w/ int relation w/ missing term"}, 63 | {"id LIKE", false, true, "invalid 'like' relation w/ missing term"}, 64 | {"id IN", false, true, "invalid 'IN' relation w/ missing bind marker/term"}, 65 | {"id IN 0", false, true, "invalid 'IN' relation w/ unexpect term"}, 66 | {"id IN (", false, true, "invalid 'IN' relation w/ missing closing paren and empty"}, 67 | {"id IN ('a'", false, true, "invalid 'IN' relation w/ missing closing paren"}, 68 | {"(id1, id2)", false, true, "invalid identifiers w/ no operator"}, 69 | {"id1, id2) IN ()", false, true, "invalid identifiers w/ missing opening paren"}, 70 | {"(id1, id2 IN ()", false, true, "invalid identifiers w/ missing closing paren"}, 71 | {"(id1, id2) IN ('a', 1", false, true, "invalid identifiers w/ missing terms closing paren"}, 72 | {"(id1, id2) == ('a', 1)", false, true, "invalid identifiers w/ invalid operator"}, 73 | 74 | // Not idempotent 75 | {"id > now()", false, false, "simple operator relation w/ 'now()'"}, 76 | {"id LIKE now()", false, false, "'like' relation w/ 'now()'"}, 77 | {"id CONTAINS now()", false, false, "'contains' relation w/ 'now()'"}, 78 | {"id CONTAINS KEY now()", false, false, "'contains key' relation w/ 'now()'"}, 79 | {"id1 IN (now(), uuid())", false, false, "'in' relation w/ 'now()'"}, 80 | {"(id1, id2) IN (now(), 
uuid())", false, false, "list 'IN' relation w/ 'now()'"}, 81 | {"(id1, id2) < (now(), uuid())", false, false, "list operator reation w/ 'now()'"}, 82 | } 83 | 84 | for _, tt := range tests { 85 | var l lexer 86 | l.init(tt.relation) 87 | idempotent, err := parseRelation(&l, l.next()) 88 | assert.Equal(t, tt.idempotent, idempotent, tt.msg) 89 | assert.True(t, (err != nil) == tt.hasError, tt.msg) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /parser/parse_select.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import ( 18 | "errors" 19 | "fmt" 20 | "strings" 21 | ) 22 | 23 | // Determines if the proxy handles the select statement. 24 | // 25 | // Currently, the only handled 'SELECT' queries are for tables in the 'system' keyspace and are matched by the 26 | // `isSystemTable()` function. This includes 'system.local', 'system.peers/peers_v2', and legacy schema tables. 27 | // 28 | // selectStmt: 'SELECT' 'JSON'? 'DISTINCT'? selectClause 'FROM' tableName ...
29 | // selectClause: '*' | selectors 30 | // 31 | // Note: Exclusiveness of '*' not enforced 32 | func isHandledSelectStmt(l *lexer, keyspace Identifier) (handled bool, stmt Statement, err error) { 33 | l.mark() // Mark here because we might come back to parse the selector 34 | t := untilToken(l, tkFrom) 35 | 36 | if tkFrom != t { 37 | return false, nil, errors.New("expected 'FROM' in select statement") 38 | } 39 | 40 | if t = l.next(); tkIdentifier != t { 41 | return false, nil, errors.New("expected identifier after 'FROM' in select statement") 42 | } 43 | 44 | qualifyingKeyspace, table, t, err := parseQualifiedIdentifier(l) 45 | if err != nil || (!keyspace.equal("system") && !qualifyingKeyspace.equal("system")) || !isSystemTable(table) { 46 | return false, nil, err 47 | } 48 | 49 | selectStmt := &SelectStatement{Keyspace: "system", Table: table.id} 50 | 51 | // This only parses the selectors if this is a query handled by the proxy 52 | 53 | l.rewind() // Rewind to the selectors 54 | for t = l.next(); tkFrom != t && tkEOF != t; t = skipToken(l, t, tkComma) { 55 | if tkIdentifier == t && (isUnreservedKeyword(l, t, "json") || isUnreservedKeyword(l, t, "distinct")) { 56 | return true, nil, errors.New("proxy is unable to do 'JSON' or 'DISTINCT' for handled system queries") 57 | } 58 | var selector Selector 59 | selector, t, err = parseSelector(l, t) 60 | if err != nil { 61 | return true, nil, err 62 | } 63 | selectStmt.Selectors = append(selectStmt.Selectors, selector) 64 | } 65 | 66 | return true, selectStmt, nil 67 | } 68 | 69 | func isHandledUseStmt(l *lexer) (handled bool, stmt Statement, err error) { 70 | t := l.next() 71 | if tkIdentifier != t { 72 | return false, nil, errors.New("expected identifier after 'USE' in use statement") 73 | } 74 | return true, &UseStatement{Keyspace: l.identifierStr()}, nil 75 | } 76 | 77 | // Parses selectors in the select clause of a select statement. 
78 | // 79 | // selectors: selector ( ',' selector )* 80 | // selector: unaliasedSelector ( 'AS' identifier ) 81 | // unaliasedSelector: 82 | // 83 | // identifier 84 | // 'COUNT(*)' | 'COUNT' '(' identifier ')' | NOW()' 85 | // term 86 | // 'CAST' '(' unaliasedSelector 'AS' primitiveType ')' 87 | // 88 | // Note: Doesn't handle term or cast 89 | func parseSelector(l *lexer, t token) (selector Selector, next token, err error) { 90 | switch t { 91 | case tkIdentifier: 92 | name := l.identifierStr() 93 | l.mark() 94 | if tkLparen == l.next() { 95 | var args []string 96 | for t = l.next(); tkRparen != t && tkEOF != t; t = skipToken(l, l.next(), tkComma) { 97 | if tkStar == t { 98 | args = append(args, "*") 99 | } else if tkIdentifier == t { 100 | args = append(args, l.identifierStr()) 101 | } else { 102 | return nil, tkInvalid, fmt.Errorf("unexpected argument type for function call '%s(...)' in select statement", name) 103 | } 104 | } 105 | if tkRparen != t { 106 | return nil, tkInvalid, fmt.Errorf("expected closing ')' for function call '%s' in select statement", name) 107 | } 108 | if strings.EqualFold(name, "count") { 109 | if len(args) == 0 { 110 | return nil, tkInvalid, fmt.Errorf("expected * or identifier in argument 'COUNT(...)' in select statement") 111 | } 112 | return &CountFuncSelector{Arg: args[0]}, l.next(), nil 113 | } else if strings.EqualFold(name, "now") { 114 | if len(args) != 0 { 115 | return nil, tkInvalid, fmt.Errorf("unexpected argument for 'NOW()' function call in select statement") 116 | } 117 | return &NowFuncSelector{}, l.next(), nil 118 | } else { 119 | return nil, tkInvalid, fmt.Errorf("unsupported function call '%s' in select statement", name) 120 | } 121 | } else { 122 | l.rewind() 123 | selector = &IDSelector{Name: name} 124 | } 125 | case tkStar: 126 | return &StarSelector{}, l.next(), nil 127 | default: 128 | return nil, tkInvalid, errors.New("unsupported select clause for system table") 129 | } 130 | 131 | if t = l.next(); 
isUnreservedKeyword(l, t, "as") { 132 | if tkIdentifier != l.next() { 133 | return nil, tkInvalid, errors.New("expected identifier after 'AS' in select statement") 134 | } 135 | return &AliasSelector{Selector: selector, Alias: l.identifierStr()}, l.next(), nil 136 | } 137 | 138 | return selector, t, nil 139 | } -------------------------------------------------------------------------------- /parser/parse_term.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import ( 18 | "errors" 19 | ) 20 | 21 | type termType int 22 | 23 | const ( 24 | termInvalid termType = iota 25 | termIntegerLiteral // Special because it can be used to index lists for deletes 26 | termPrimitiveLiteral 27 | termListLiteral 28 | termSetMapUdtLiteral // All use curly, distinction not important 29 | termTupleLiteral 30 | termBindMarker 31 | termFunctionCall 32 | termCast 33 | ) 34 | 35 | // Determines if a term is idempotent and also returns the top-level type of the term. 36 | // 37 | // A term is not idempotent if it contains a non-idempotent function e.g. 
'now()' or 'uuid()' 38 | // 39 | // term: literal | bindMarker | functionCall | typeCast 40 | // 41 | // literal: primitiveLiteral | collectionLiteral | tupleLiteral | udtLiteral | 'NULL' 42 | // primitiveLiteral: stringLiteral | integer | float | boolean | duration | uuid | hexNumber | '-'? 'NAN' | '-'? 'INFINITY' 43 | // collectionLiteral: listLiteral | setLiteral | mapLiteral 44 | // listLiteral: '[' terms? ']' 45 | // setLiteral: '{' terms? '}' 46 | // mapLiteral: '{' ( mapEntry ( ',' mapEntry )* )? '}' 47 | // mapEntry: term ':' term 48 | // tupleLiteral: '(' terms? ')' 49 | // udtLiteral: '{' ( fieldEntry ( ',' fieldEntry )* )? '}' 50 | // fieldEntry: identifier ':' term 51 | // 52 | // bindMarker: '?' | ':' identifier 53 | // 54 | // functionCall: ( identifier '.' )? identifier '(' functionArg ( ',' functionArg )* ')' 55 | // functionArg: identifier | term 56 | // 57 | // typeCast: '(' type ')' term 58 | // type: identifier | identifier '<' type '>' 59 | // 60 | func parseTerm(l *lexer, t token) (idempotent bool, typ termType, err error) { 61 | switch t { 62 | case tkInteger: // Integer lister 63 | return true, termIntegerLiteral, nil 64 | case tkFloat, tkBool, tkNull, tkStringLiteral, tkHexNumber, tkUuid, tkDuration, tkNan, tkInfinity: // Literal 65 | return true, termPrimitiveLiteral, nil 66 | case tkColon: // Named bind marker 67 | if t = l.next(); t != tkIdentifier { 68 | return false, termBindMarker, errors.New("expected identifier after ':' for named bind marker") 69 | } 70 | return true, termBindMarker, nil 71 | case tkQMark: // Positional bind marker 72 | return true, termBindMarker, nil 73 | case tkLsquare: // List literal 74 | return parseListTerm(l) 75 | case tkLcurly: // Set, map, or UDT literal 76 | if t = l.next(); t == tkIdentifier { // maybe UDT 77 | l.mark() 78 | var maybeColon token 79 | _, _, maybeColon, err = parseQualifiedIdentifier(l) 80 | if err != nil { 81 | return false, termSetMapUdtLiteral, err 82 | } 83 | l.rewind() 84 | if 
tkColon == maybeColon { // UDT 85 | return parseUDTTerm(l, t) 86 | } else { // Set or map (probably starting with a function) 87 | return parseSetOrMapTerm(l, t) 88 | } 89 | } else { // Set or map 90 | return parseSetOrMapTerm(l, t) 91 | } 92 | case tkLparen: // Type cast or tuple literal 93 | if t = l.next(); t == tkIdentifier { // Cast 94 | return parseCastTerm(l, t) 95 | } else { // Tuple literal 96 | return parseTupleTerm(l, t) 97 | } 98 | case tkIdentifier: // Function 99 | return parseFunctionTerm(l) 100 | } 101 | 102 | return false, termInvalid, errors.New("invalid term") 103 | } 104 | 105 | func parseListTerm(l *lexer) (idempotent bool, typ termType, err error) { 106 | var t token 107 | for t = l.next(); t != tkRsquare && t != tkEOF; t = skipToken(l, l.next(), tkComma) { 108 | if idempotent, _, err = parseTerm(l, t); !idempotent { 109 | return idempotent, termListLiteral, err 110 | } 111 | } 112 | if t != tkRsquare { 113 | return false, termListLiteral, errors.New("expected closing ']' bracket for list literal") 114 | } 115 | return true, termListLiteral, nil 116 | } 117 | 118 | func parseUDTTerm(l *lexer, t token) (idempotent bool, typ termType, err error) { 119 | for t != tkRcurly && t != tkEOF { 120 | if tkIdentifier != t { 121 | return false, termSetMapUdtLiteral, errors.New("expected identifier in UDT literal field") 122 | } 123 | _, _, t, err = parseQualifiedIdentifier(l) 124 | if err != nil { 125 | return false, termSetMapUdtLiteral, err 126 | } 127 | t = skipToken(l, l.next(), tkColon) 128 | if idempotent, typ, err = parseTerm(l, t); !idempotent { 129 | return idempotent, termSetMapUdtLiteral, err 130 | } 131 | t = skipToken(l, l.next(), tkComma) 132 | } 133 | if t != tkRcurly { 134 | return false, termSetMapUdtLiteral, errors.New("expected closing '}' bracket for UDT literal") 135 | } 136 | return true, termSetMapUdtLiteral, nil 137 | } 138 | 139 | func parseSetOrMapTerm(l *lexer, t token) (idempotent bool, typ termType, err error) { 140 | for t != 
tkRcurly && t != tkEOF { 141 | if idempotent, typ, err = parseTerm(l, t); !idempotent { 142 | return idempotent, termSetMapUdtLiteral, err 143 | } 144 | if t = l.next(); tkColon == t { // Map 145 | if idempotent, typ, err = parseTerm(l, l.next()); !idempotent { 146 | return idempotent, termSetMapUdtLiteral, err 147 | } 148 | t = l.next() 149 | } 150 | t = skipToken(l, t, tkComma) 151 | } 152 | if t != tkRcurly { 153 | return false, termSetMapUdtLiteral, errors.New("expected closing '}' bracket for set/map literal") 154 | } 155 | return true, termSetMapUdtLiteral, nil 156 | } 157 | 158 | func parseCastTerm(l *lexer, t token) (idempotent bool, typ termType, err error) { 159 | t, err = parseType(l) 160 | if err != nil { 161 | return false, termCast, err 162 | } 163 | if t != tkRparen { 164 | return false, termCast, errors.New("expected closing ')' bracket for type cast") 165 | } 166 | if idempotent, typ, err = parseTerm(l, l.next()); !idempotent { 167 | return idempotent, termCast, err 168 | } 169 | return true, termCast, err 170 | } 171 | 172 | func parseTupleTerm(l *lexer, t token) (idempotent bool, typ termType, err error) { 173 | for t != tkRparen && t != tkEOF { 174 | if idempotent, _, err = parseTerm(l, t); !idempotent { 175 | return idempotent, termTupleLiteral, err 176 | } 177 | t = skipToken(l, l.next(), tkComma) 178 | } 179 | if t != tkRparen { 180 | return false, termTupleLiteral, errors.New("expected closing ')' bracket for tuple literal") 181 | } 182 | return true, termTupleLiteral, nil 183 | } 184 | 185 | func parseFunctionTerm(l *lexer) (idempotent bool, typ termType, err error) { 186 | var target, keyspace Identifier 187 | keyspace, target, t, err := parseQualifiedIdentifier(l) 188 | if err != nil { 189 | return false, termFunctionCall, err 190 | } 191 | if tkLparen != t { 192 | return false, termFunctionCall, errors.New("invalid term, was expecting function call") 193 | } 194 | for t = l.next(); t != tkRparen && t != tkEOF; t = skipToken(l, l.next(), 
tkComma) { 195 | l.mark() 196 | maybeCommaOrRparen := l.next() 197 | if tkIdentifier == t && (tkComma == maybeCommaOrRparen || tkRparen == maybeCommaOrRparen) { 198 | l.rewind() 199 | } else { 200 | l.rewind() 201 | if idempotent, _, err = parseTerm(l, t); !idempotent { 202 | return idempotent, termFunctionCall, err 203 | } 204 | } 205 | } 206 | if t != tkRparen { 207 | return false, termFunctionCall, errors.New("expected closing ')' for function call") 208 | } 209 | return !(isNonIdempotentFunc(target) && (keyspace.isEmpty() || keyspace.equal("system"))), termFunctionCall, nil 210 | } 211 | 212 | func parseType(l *lexer) (t token, err error) { 213 | if t = l.next(); tkLangle == t { 214 | for t = l.next(); tkRangle != t && tkEOF != t; t = skipToken(l, l.next(), tkComma) { 215 | if t != tkIdentifier { 216 | return tkInvalid, errors.New("expected sub-type in type parameter") 217 | } 218 | } 219 | if tkRangle != t { 220 | return tkInvalid, errors.New("expected closing '>' bracket for type") 221 | } 222 | return l.next(), nil 223 | } 224 | return t, nil 225 | } 226 | -------------------------------------------------------------------------------- /parser/parse_term_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package parser 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestParseTerm(t *testing.T) { 24 | var tests = []struct { 25 | term string 26 | idempotent bool 27 | typ termType 28 | hasError bool 29 | msg string 30 | }{ 31 | {"system.someOtherFunc()", true, termFunctionCall, false, "qualified idempotent function"}, 32 | {"[1, 2, 3]", true, termListLiteral, false, "list literal"}, 33 | {"123", true, termIntegerLiteral, false, "integer literal"}, 34 | {"{1, 2, 3}", true, termSetMapUdtLiteral, false, "set literal"}, 35 | {"{ a: 1, a.b: 2, c: 3}", true, termSetMapUdtLiteral, false, "UDT literal"}, 36 | {"{ 'a': 1, 'b': 2, 'c': 3}", true, termSetMapUdtLiteral, false, "map literal"}, 37 | {"(1, 'a', [])", true, termTupleLiteral, false, "tuple literal"}, 38 | {":abc", true, termBindMarker, false, "named bind marker"}, 39 | {"?", true, termBindMarker, false, "positional bind marker"}, 40 | {"(map)1", true, termCast, false, "cast to a complex type"}, 41 | {"func(a, b, c)", true, termFunctionCall, false, "function with identifier args"}, 42 | 43 | // Invalid 44 | {"system.someOtherFunc", false, termFunctionCall, true, "invalid qualified function"}, 45 | {"func", false, termFunctionCall, true, "invalid function"}, 46 | {"[1, 2, 3", false, termListLiteral, true, "invalid list literal"}, 47 | {"{1, 2, 3", false, termSetMapUdtLiteral, true, "invalid set literal"}, 48 | {"{ a: 1, a.b: 2, c: 3", false, termSetMapUdtLiteral, true, "invalid UDT literal"}, 49 | {"{ 'a': 1, 'b': 2, 'c': 3", false, termSetMapUdtLiteral, true, "invalid map literal"}, 50 | {"+123", false, termInvalid, true, "invalid term"}, 51 | 52 | // Not idempotent 53 | {"system.now()", false, termFunctionCall, false, "qualified 'now()' function"}, 54 | {"system.uuid()", false, termFunctionCall, false, "qualified 'uuid()' function "}, 55 | {"(uuid)now()", false, termCast, false, "cast of the 'now()' function"}, 56 | {"now(a, b, c, '1', 0)", false, 
termFunctionCall, false, "'now()' function w/ args"}, 57 | {"[now(), 2, 3]", false, termListLiteral, false, "list literal with 'now()' function"}, 58 | {"{now():'a'}", false, termSetMapUdtLiteral, false, "map literal with 'now()' function"}, 59 | } 60 | 61 | for _, tt := range tests { 62 | var l lexer 63 | l.init(tt.term) 64 | idempotent, typ, err := parseTerm(&l, l.next()) 65 | assert.Equal(t, tt.idempotent, idempotent, tt.msg) 66 | assert.Equal(t, tt.typ, typ, tt.msg) 67 | assert.True(t, (err != nil) == tt.hasError, tt.msg) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /parser/parse_update.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import "errors" 18 | 19 | // Determines if an update statement is idempotent. 20 | // 21 | // An update statement not idempotent if: 22 | // * it contains an update operation that appends/prepends to a list or updates a counter 23 | // * uses a lightweight transaction (LWT) e.g. 'IF EXISTS' or 'IF a > 0' 24 | // * has an update operation or relation that uses a non-idempotent function e.g. now() or uuid() 25 | // 26 | // updateStatement: 'UPDATE' tableName usingClause? 'SET' updateOperations whereClause 'IF' ( 'EXISTS' | conditions )? 27 | // tableName: ( identifier '.' )? 
identifier 28 | // 29 | func isIdempotentUpdateStmt(l *lexer) (idempotent bool, t token, err error) { 30 | t = l.next() 31 | if tkIdentifier != t { 32 | return false, tkInvalid, errors.New("expected identifier after 'UPDATE' in update statement") 33 | } 34 | 35 | _, _, t, err = parseQualifiedIdentifier(l) 36 | if err != nil { 37 | return false, tkInvalid, err 38 | } 39 | 40 | t, err = parseUsingClause(l, t) 41 | if err != nil { 42 | return false, tkInvalid, err 43 | } 44 | 45 | for !isUnreservedKeyword(l, t, "set") { 46 | return false, tkInvalid, errors.New("expected 'SET' in update statement") 47 | } 48 | 49 | for t = l.next(); tkIf != t && tkWhere != t && !isDMLTerminator(t); t = skipToken(l, l.next(), tkComma) { 50 | idempotent, err = parseUpdateOp(l, t) 51 | if !idempotent { 52 | return idempotent, tkInvalid, err 53 | } 54 | } 55 | 56 | if tkWhere == t { 57 | idempotent, t, err = parseWhereClause(l) 58 | if !idempotent { 59 | return idempotent, tkInvalid, err 60 | } 61 | } 62 | 63 | for ; !isDMLTerminator(t); t = l.next() { 64 | if tkIf == t { 65 | return false, tkInvalid, nil 66 | } 67 | } 68 | return true, t, nil 69 | } 70 | 71 | // Parse over using clause. 
72 | // 73 | // usingClause 74 | // : 'USING' timestamp 75 | // | 'USING' ttl 76 | // | 'USING' timestamp 'AND' ttl 77 | // | 'USING' ttl 'AND' timestamp 78 | // 79 | func parseUsingClause(l *lexer, t token) (next token, err error) { 80 | if tkUsing == t { 81 | err = parseTtlOrTimestamp(l) 82 | if err != nil { 83 | return tkInvalid, err 84 | } 85 | if t = l.next(); tkAnd == t { 86 | err = parseTtlOrTimestamp(l) 87 | if err != nil { 88 | return tkInvalid, err 89 | } 90 | return l.next(), nil 91 | } 92 | } 93 | return t, nil 94 | } 95 | 96 | // Parse over TTL or timestamp 97 | // 98 | // timestamp: 'TIMESTAMP' ( INTEGER | bindMarker ) 99 | // ttl: 'TTL' ( INTEGER | bindMarker ) 100 | // 101 | func parseTtlOrTimestamp(l *lexer) error { 102 | var t token 103 | if t = l.next(); !isUnreservedKeyword(l, t, "ttl") && !isUnreservedKeyword(l, t, "timestamp") { 104 | return errors.New("expected 'TTL' or 'TIMESTAMP' after 'USING'") 105 | } 106 | switch t = l.next(); t { 107 | case tkInteger: 108 | return nil 109 | case tkColon, tkQMark: 110 | return parseBindMarker(l, t) 111 | } 112 | return errors.New("expected integer or bind marker after 'TTL' or 'TIMESTAMP'") 113 | } 114 | -------------------------------------------------------------------------------- /parser/parse_update_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestIsIdempotentUpdateStmt(t *testing.T) { 24 | var tests = []struct { 25 | query string 26 | idempotent bool 27 | hasError bool 28 | msg string 29 | }{ 30 | {"UPDATE table SET a = 0", true, false, "simple table"}, 31 | {"UPDATE table SET a = 0;", true, false, "semicolong"}, 32 | {"UPDATE table SET a = 0, b = 0", true, false, "multiple updates"}, 33 | {"UPDATE ks.table SET a = 0", true, false, "simple qualified table"}, 34 | {"UPDATE ks.table USING TIMESTAMP 1234 SET a = 0", true, false, "using timestamp"}, 35 | {"UPDATE ks.table USING TIMESTAMP 1234 AND TTL 5678 SET a = 0", true, false, "using timestamp and ttl"}, 36 | {"UPDATE ks.table USING TTL 1234 SET a = 0", true, false, "using ttl"}, 37 | {"UPDATE ks.table USING TTL 1234 AND TIMESTAMP 5678 SET a = 0", true, false, "using ttl and timestamp"}, 38 | {"UPDATE ks.table SET a = 0 WHERE a > 100", true, false, "where clause"}, 39 | 40 | // Invalid 41 | {"UPDATE table", false, true, "no 'SET'"}, 42 | {"UPDATE table a = 0", false, true, "no 'SET' w/ update operation"}, 43 | {"UPDATE table a = 0 WHERE", false, true, "where clause no relations"}, 44 | {"UPDATE table SET a = 0,", true, false, "multiple updates no operation"}, 45 | {"UPDATE ks. 
SET a = 0", false, true, "no table"}, 46 | {"UPDATE table USING SET a = 0", false, true, "no timestamp or ttl"}, 47 | {"UPDATE table USING TTL SET a = 0", false, true, "ttl no value"}, 48 | {"UPDATE table USING TTL 1234 AND SET a = 0", false, true, "no ttl/timestamp after 'AND' in using clause"}, 49 | 50 | // Not idempotent 51 | {"UPDATE table SET a = now()", false, false, "simple w/ now()"}, 52 | {"UPDATE table SET a = a + 1", false, false, "add to counter"}, 53 | {"UPDATE table USING TIMESTAMP 1234 SET a = now()", false, false, "using clause w/ now()"}, 54 | {"UPDATE table SET a = 0 WHERE a > toTimestamp(toDate(now()))", false, false, "where clause w/ now()"}, 55 | {"UPDATE table SET a = 0 WHERE a > 0 IF EXISTS", false, false, "where clause w/ LWT"}, 56 | {"UPDATE table SET a = 0 IF EXISTS", false, false, "LWT"}, 57 | {"UPDATE table SET a = 0 IF EXISTS;", false, false, "LWT and semicolon"}, 58 | } 59 | 60 | for _, tt := range tests { 61 | idempotent, err := IsQueryIdempotent(tt.query) 62 | assert.True(t, (err != nil) == tt.hasError, tt.msg) 63 | assert.Equal(t, tt.idempotent, idempotent, "invalid idempotency", tt.msg) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /parser/parse_updateop.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package parser 16 | 17 | import "errors" 18 | 19 | // Determines if an update operation is idempotent. 20 | // 21 | // Non-idempotent update operations include: 22 | // * Using a non-idempotent function e.g. now(), uuid() 23 | // * Prepends or appends to a list type 24 | // * Increments or decrements a counter 25 | // 26 | // Important: There are currently some ambiguous cases where if the type is not known we cannot correctly 27 | // determine if an operation is idempotent. These include: 28 | // * Using a bind marker (this could be fixed for prepared statements using the prepared metadata) 29 | // * Function calls 30 | // 31 | // updateOperation 32 | // : identifier '=' term ( '+' identifier )? 33 | // | identifier '=' identifier ( '+' | '-' ) term 34 | // | identifier ( '+=' | '-=' ) term 35 | // | identifier '[' term ']' '=' term 36 | // | identifier '.' identifier '=' term 37 | // 38 | func parseUpdateOp(l *lexer, t token) (idempotent bool, err error) { 39 | if tkIdentifier != t { 40 | return false, errors.New("expected identifier after 'SET' in update statement") 41 | } 42 | 43 | var typ termType 44 | 45 | switch t = l.next(); t { 46 | case tkEqual: 47 | l.mark() 48 | maybeId, maybeAddOrSub := l.next(), l.next() 49 | if tkIdentifier == maybeId && (tkAdd == maybeAddOrSub || tkSub == maybeAddOrSub) { // identifier = identifier + term | identifier = identifier - term 50 | t = l.next() 51 | if idempotent, typ, err = parseTerm(l, t); !idempotent { 52 | return idempotent, err 53 | } 54 | return isIdempotentUpdateOpTermType(typ), nil 55 | 56 | } else { 57 | l.rewind() 58 | t = l.next() 59 | if idempotent, typ, err = parseTerm(l, t); idempotent { // identifier = term | identifier = term + identifier 60 | l.mark() 61 | if t = l.next(); tkAdd == t { 62 | if tkIdentifier != l.next() { 63 | return false, errors.New("expected identifier after '+' operator in update operation") 64 | } 65 | return isIdempotentUpdateOpTermType(typ), nil 66 | } else { 67 | 
l.rewind() 68 | } 69 | } else { 70 | return idempotent, err 71 | } 72 | } 73 | case tkAddEqual, tkSubEqual: // identifier += term | identifier -= term 74 | t = l.next() 75 | if idempotent, typ, err = parseTerm(l, t); !idempotent { 76 | return idempotent, err 77 | } 78 | return isIdempotentUpdateOpTermType(typ), nil 79 | case tkLsquare: // identifier '[' term ']' = term 80 | if idempotent, _, err = parseTerm(l, l.next()); !idempotent { 81 | return idempotent, err 82 | } 83 | if tkRsquare != l.next() { 84 | return false, errors.New("expected closing ']' in update operation") 85 | } 86 | if tkEqual != l.next() { 87 | return false, errors.New("expected '=' in update operation") 88 | } 89 | if idempotent, _, err = parseTerm(l, l.next()); !idempotent { 90 | return idempotent, err 91 | } 92 | case tkDot: // identifier '.' identifier '=' term 93 | if tkIdentifier != l.next() { 94 | return false, errors.New("expected identifier after '.' in update operation") 95 | } 96 | if tkEqual != l.next() { 97 | return false, errors.New("expected '=' in update operation") 98 | } 99 | if idempotent, _, err = parseTerm(l, l.next()); !idempotent { 100 | return idempotent, err 101 | } 102 | default: 103 | return false, errors.New("unexpected token in update operation") 104 | } 105 | 106 | return true, nil 107 | } 108 | 109 | // Update terms can be one of the following: 110 | // * Literal (idempotent, if not a list) 111 | // * Bind marker (ambiguous, so not idempotent) 112 | // * Function call (ambiguous, so not idempotent) 113 | // * Type cast (probably not idempotent) 114 | func isIdempotentUpdateOpTermType(typ termType) bool { 115 | return typ == termSetMapUdtLiteral || typ == termTupleLiteral 116 | } 117 | -------------------------------------------------------------------------------- /parser/parse_updateop_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package parser 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | 23 | func TestParseUpdateOp(t *testing.T) { 24 | var tests = []struct { 25 | updateOp string 26 | idempotent bool 27 | hasError bool 28 | msg string 29 | }{ 30 | {"a = 0", true, false, "simple update operation"}, 31 | {"a['a'] = 0", true, false, "assign to collection"}, 32 | {"a[0] = 0", true, false, "assign to collection (integer index)"}, 33 | {"a.b = 0", true, false, "assign to UDT field"}, 34 | {"a = [1, 2, 3]", true, false, "assign list"}, 35 | {"a = {1, 2, 3}", true, false, "assign set"}, 36 | {"a = {'a': 1, 'b': 2, 'c': 3}", true, false, "assign map"}, 37 | 38 | {"a = a + {1, 2}", true, false, "insert into set"}, 39 | {"a += {1, 2}", true, false, "insert (assign) into set"}, 40 | {"a = a - {1, 2}", true, false, "remove from set"}, 41 | {"a -= {1, 2}", true, false, "remove (assign) from set"}, 42 | 43 | {"a = a + {'a': 1, 'b': 2}", true, false, "insert into map"}, 44 | {"a += {'a': 1, 'b': 2}", true, false, "insert (assign) into map"}, 45 | {"a = a - {'a': 1, 'b': 2}", true, false, "remove from map"}, 46 | {"a -= {'a': 1, 'b': 2}", true, false, "remove (assign) from map"}, 47 | 48 | {"a = a + (1, 'a')", true, false, "insert into tuple"}, 49 | {"a += (1, 'a')", true, false, "insert (assign) into tuple"}, 50 | {"a = a - (1, 'a')", true, false, "remove 
from tuple"}, 51 | {"a -= (1, 'a')", true, false, "remove (assign) from tuple"}, 52 | 53 | // Invalid 54 | {"0 = a", false, true, "start w/ term"}, 55 | {"a[0 = a", false, true, "no closing square bracket"}, 56 | {"a0] = a", false, true, "no opening square bracket"}, 57 | {"a. = a", false, true, "no identifier after '.' for UDT field"}, 58 | 59 | {"a = a +", false, true, "add with no term"}, 60 | {"a = a -", false, true, "subtract with no term"}, 61 | {"a = 1 +", false, true, "add with no identifier"}, 62 | {"a +=", false, true, "add assign with no term"}, 63 | {"a -=", false, true, "subtract assign with no term"}, 64 | 65 | // Not idempotent 66 | {"a = now()", false, false, "simple update operation w/ now()"}, 67 | {"a[0] = now()", false, false, "assign to collection (integer index) w/ now()"}, 68 | {"a[now()] = 0", false, false, "assign to collection w/ now() index"}, 69 | {"a.b = now()", false, false, "assign to UDT field with now()"}, 70 | 71 | {"a = a + 1", false, false, "add to counter"}, 72 | {"a += 1", false, false, "add assign to counter"}, 73 | {"a = 1 + a", false, false, "add to counter swap term"}, 74 | {"a = a - 1", false, false, "subtract from counter"}, 75 | {"a -= 1", false, false, "subtract assign from counter"}, 76 | 77 | {"a = a + [1, 2]", false, false, "append to list"}, 78 | {"a += [1, 2]", false, false, "append assign to list"}, 79 | {"a = [1, 2] + a", false, false, "prepend to list"}, 80 | 81 | {"a = a - [1, 2]", false, false, "remove from list"}, 82 | {"a -= [1, 2]", false, false, "remove assign to list"}, 83 | 84 | {"a = a + (int)1", false, false, "add/append w/ cast"}, 85 | {"a += (int)1", false, false, "add/append assign w/ cast"}, 86 | {"a = (int)1 + a", false, false, "add/append (swap term) w/ cast"}, 87 | {"a = a - (int)1", false, false, "subtract/remove w/ cast"}, 88 | {"a -= (int)1", false, false, "subtract/remove assign w/ cast"}, 89 | 90 | // Ambiguous 91 | {"a = a + ?", false, false, "add/append w/ bind marker"}, 92 | {"a += ?", 
false, false, "add/append assign w/ bind marker"}, 93 | {"a = ? + a", false, false, "add/append (swap term) w/ bind marker"}, 94 | {"a = a - ?", false, false, "subtract/remove w/ bind marker"}, 95 | {"a -= ?", false, false, "subtract/remove assign w/ bind marker"}, 96 | 97 | {"a = a + :name", false, false, "add/append w/ named bind marker"}, 98 | {"a += :name", false, false, "add/append assign w/ named bind marker"}, 99 | {"a = :name + a", false, false, "add/append (swap term) w/ named bind marker"}, 100 | {"a = a - :name", false, false, "subtract/remove w/ named bind marker"}, 101 | {"a -= :name", false, false, "subtract/remove assign w/ named bind marker"}, 102 | 103 | {"a = a + func()", false, false, "add/append w/ function"}, 104 | {"a += func()", false, false, "add/append assign w/ function"}, 105 | {"a = func() + a", false, false, "add/append (swap term) w/ function"}, 106 | {"a = a - func()", false, false, "subtract/remove w/ function"}, 107 | {"a -= func()", false, false, "subtract/remove assign w/ function"}, 108 | } 109 | 110 | for _, tt := range tests { 111 | var l lexer 112 | l.init(tt.updateOp) 113 | idempotent, err := parseUpdateOp(&l, l.next()) 114 | assert.Equal(t, tt.idempotent, idempotent, tt.msg) 115 | assert.True(t, (err != nil) == tt.hasError, tt.msg) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /parser/parser.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //go:generate ragel -Z -G2 lexer.rl -o lexer.go 16 | //go:generate go fmt lexer.go 17 | 18 | package parser 19 | 20 | import ( 21 | "errors" 22 | "fmt" 23 | "strings" 24 | 25 | "github.com/datastax/go-cassandra-native-protocol/datatype" 26 | "github.com/datastax/go-cassandra-native-protocol/message" 27 | "github.com/google/uuid" 28 | ) 29 | 30 | type Selector interface { 31 | Values(columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) 32 | Columns(columns []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) 33 | } 34 | 35 | type AliasSelector struct { 36 | Selector Selector 37 | Alias string 38 | } 39 | 40 | func (a AliasSelector) Values(columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { 41 | return a.Selector.Values(columns, valueFunc) 42 | } 43 | 44 | func (a AliasSelector) Columns(columns []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) { 45 | cols, err := a.Selector.Columns(columns, stmt) 46 | if err != nil { 47 | return 48 | } 49 | for _, column := range cols { 50 | alias := *column // Make a copy so we can modify the name 51 | alias.Name = a.Alias 52 | filtered = append(filtered, &alias) 53 | } 54 | return 55 | } 56 | 57 | type IDSelector struct { 58 | Name string 59 | } 60 | 61 | func (i IDSelector) Values(_ []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { 62 | 
value, err := valueFunc(i.Name) 63 | if err != nil { 64 | return 65 | } 66 | return []message.Column{value}, err 67 | } 68 | 69 | func (i IDSelector) Columns(columns []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) { 70 | if column := FindColumnMetadata(columns, i.Name); column != nil { 71 | return []*message.ColumnMetadata{column}, nil 72 | } else { 73 | return nil, fmt.Errorf("invalid column %s", i.Name) 74 | } 75 | } 76 | 77 | type StarSelector struct{} 78 | 79 | func (s StarSelector) Values(columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { 80 | for _, column := range columns { 81 | var val message.Column 82 | val, err = valueFunc(column.Name) 83 | if err != nil { 84 | return 85 | } 86 | filtered = append(filtered, val) 87 | } 88 | return 89 | } 90 | 91 | func (s StarSelector) Columns(columns []*message.ColumnMetadata, _ *SelectStatement) (filtered []*message.ColumnMetadata, err error) { 92 | filtered = columns 93 | return 94 | } 95 | 96 | type CountFuncSelector struct { 97 | Arg string 98 | } 99 | 100 | func (s CountFuncSelector) Values(_ []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { 101 | val, err := valueFunc(CountValueName) 102 | if err != nil { 103 | return 104 | } 105 | filtered = append(filtered, val) 106 | return 107 | } 108 | 109 | func (s CountFuncSelector) Columns(_ []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) { 110 | name := "count" 111 | if s.Arg != "*" { 112 | name = fmt.Sprintf("system.count(%s)", strings.ToLower(s.Arg)) 113 | } 114 | return []*message.ColumnMetadata{{ 115 | Keyspace: stmt.Keyspace, 116 | Table: stmt.Table, 117 | Name: name, 118 | Type: datatype.Int, 119 | }}, nil 120 | } 121 | 122 | type NowFuncSelector struct{} 123 | 124 | func (s NowFuncSelector) Values(_ []*message.ColumnMetadata, _ ValueLookupFunc) (filtered []message.Column, 
err error) { 125 | u, err := uuid.NewUUID() 126 | if err != nil { 127 | return 128 | } 129 | filtered = append(filtered, u[:]) 130 | return 131 | } 132 | 133 | func (s NowFuncSelector) Columns(_ []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) { 134 | return []*message.ColumnMetadata{{ 135 | Keyspace: stmt.Keyspace, 136 | Table: stmt.Table, 137 | Name: "system.now()", 138 | Type: datatype.Timeuuid, 139 | }}, nil 140 | } 141 | type Statement interface { 142 | isStatement() 143 | } 144 | 145 | type SelectStatement struct { 146 | Keyspace string 147 | Table string 148 | Selectors []Selector 149 | } 150 | 151 | func (s SelectStatement) isStatement() {} 152 | 153 | type UseStatement struct { 154 | Keyspace string 155 | } 156 | 157 | func (u UseStatement) isStatement() {} 158 | 159 | var defaultSelectStatement = &SelectStatement{} 160 | 161 | // IsQueryHandled parses the query string and determines if the query is handled by the proxy 162 | func IsQueryHandled(keyspace Identifier, query string) (handled bool, stmt Statement, err error) { 163 | var l lexer 164 | l.init(query) 165 | 166 | t := l.next() 167 | switch t { 168 | case tkSelect: 169 | handled, stmt, err = isHandledSelectStmt(&l, keyspace) 170 | if !handled { 171 | stmt = defaultSelectStatement 172 | } 173 | return 174 | case tkUse: 175 | return isHandledUseStmt(&l) 176 | } 177 | return 178 | } 179 | 180 | // IsQueryIdempotent parses the query string and determines if the query is idempotent 181 | func IsQueryIdempotent(query string) (idempotent bool, err error) { 182 | var l lexer 183 | l.init(query) 184 | return isIdempotentStmt(&l, l.next()) 185 | } 186 | 187 | func isIdempotentStmt(l *lexer, t token) (idempotent bool, err error) { 188 | switch t { 189 | case tkSelect: 190 | return true, nil 191 | case tkUse, tkCreate, tkAlter, tkDrop: 192 | return false, nil 193 | case tkInsert: 194 | idempotent, t, err = isIdempotentInsertStmt(l) 195 | case tkUpdate: 196 | 
idempotent, t, err = isIdempotentUpdateStmt(l) 197 | case tkDelete: 198 | idempotent, t, err = isIdempotentDeleteStmt(l) 199 | case tkBegin: 200 | return isIdempotentBatchStmt(l) 201 | default: 202 | return false, errors.New("invalid statement type") 203 | } 204 | return idempotent && (t == tkEOF || t == tkEOS), err 205 | } 206 | -------------------------------------------------------------------------------- /parser/parser_utils.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package parser 16 | 17 | import ( 18 | "errors" 19 | 20 | "github.com/datastax/go-cassandra-native-protocol/message" 21 | ) 22 | 23 | const ( 24 | CountValueName = "count(*)" 25 | ) 26 | 27 | var systemTables = []string{"local", "peers", "peers_v2", "schema_keyspaces", "schema_columnfamilies", "schema_columns", "schema_usertypes"} 28 | 29 | var nonIdempotentFuncs = []string{"uuid", "now"} 30 | 31 | type ValueLookupFunc func(name string) (value message.Column, err error) 32 | 33 | func FilterValues(stmt *SelectStatement, columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { 34 | for _, selector := range stmt.Selectors { 35 | var vals []message.Column 36 | vals, err = selector.Values(columns, valueFunc) 37 | if err != nil { 38 | return nil, err 39 | } 40 | filtered = append(filtered, vals...) 41 | } 42 | return filtered, nil 43 | } 44 | 45 | func FilterColumns(stmt *SelectStatement, columns []*message.ColumnMetadata) (filtered []*message.ColumnMetadata, err error) { 46 | for _, selector := range stmt.Selectors { 47 | var cols []*message.ColumnMetadata 48 | cols, err = selector.Columns(columns, stmt) 49 | if err != nil { 50 | return nil, err 51 | } 52 | filtered = append(filtered, cols...) 
53 | } 54 | return filtered, nil 55 | } 56 | 57 | func isSystemTable(name Identifier) bool { 58 | for _, table := range systemTables { 59 | if name.equal(table) { 60 | return true 61 | } 62 | } 63 | return false 64 | } 65 | 66 | func isNonIdempotentFunc(name Identifier) bool { 67 | for _, funcName := range nonIdempotentFuncs { 68 | if name.equal(funcName) { 69 | return true 70 | } 71 | } 72 | return false 73 | } 74 | 75 | func isUnreservedKeyword(l *lexer, t token, keyword string) bool { 76 | return tkIdentifier == t && l.identifier().equal(keyword) 77 | } 78 | 79 | func skipToken(l *lexer, t token, toSkip token) token { 80 | if t == toSkip { 81 | return l.next() 82 | } 83 | return t 84 | } 85 | 86 | func untilToken(l *lexer, to token) token { 87 | var t token 88 | for to != t && tkEOF != t { 89 | t = l.next() 90 | } 91 | return t 92 | } 93 | 94 | func parseQualifiedIdentifier(l *lexer) (keyspace, target Identifier, t token, err error) { 95 | temp := l.identifier() 96 | if t = l.next(); tkDot == t { 97 | if t = l.next(); tkIdentifier != t { 98 | return Identifier{}, Identifier{}, tkInvalid, errors.New("expected another identifier after '.' 
for qualified identifier") 99 | } 100 | return temp, l.identifier(), l.next(), nil 101 | } else { 102 | return Identifier{}, temp, t, nil 103 | } 104 | } 105 | 106 | func parseIdentifiers(l *lexer, t token) (err error) { 107 | for tkRparen != t && tkEOF != t { 108 | if tkIdentifier != t { 109 | return errors.New("expected identifier") 110 | } 111 | t = skipToken(l, l.next(), tkComma) 112 | } 113 | if tkRparen != t { 114 | return errors.New("expected closing ')' for identifiers") 115 | } 116 | return nil 117 | } 118 | 119 | func isDMLTerminator(t token) bool { 120 | return t == tkEOF || t == tkEOS || t == tkInsert || t == tkUpdate || t == tkDelete || t == tkApply 121 | } 122 | -------------------------------------------------------------------------------- /proxy.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package main 16 | 17 | import ( 18 | "context" 19 | "os" 20 | "os/signal" 21 | 22 | "github.com/datastax/cql-proxy/proxy" 23 | ) 24 | 25 | func main() { 26 | ctx, cancel := signalContext(context.Background(), os.Interrupt, os.Kill) 27 | 28 | defer cancel() 29 | 30 | os.Exit(proxy.Run(ctx, os.Args[1:])) 31 | } 32 | 33 | // signalContext is a simplified version of `signal.NotifyContext()` for golang 1.15 and earlier 34 | func signalContext(parent context.Context, sig ...os.Signal) (context.Context, func()) { 35 | ctx, cancel := context.WithCancel(parent) 36 | ch := make(chan os.Signal) 37 | signal.Notify(ch, sig...) 38 | if ctx.Err() == nil { 39 | go func() { 40 | select { 41 | case <-ch: 42 | cancel() 43 | case <-ctx.Done(): 44 | } 45 | }() 46 | } 47 | return ctx, func() { 48 | cancel() 49 | signal.Stop(ch) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /proxy/request.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package proxy 16 | 17 | import ( 18 | "errors" 19 | "io" 20 | "reflect" 21 | "sync" 22 | 23 | "github.com/datastax/cql-proxy/codecs" 24 | "github.com/datastax/cql-proxy/parser" 25 | "github.com/datastax/cql-proxy/proxycore" 26 | "github.com/datastax/go-cassandra-native-protocol/primitive" 27 | 28 | "github.com/datastax/go-cassandra-native-protocol/frame" 29 | "github.com/datastax/go-cassandra-native-protocol/message" 30 | "go.uber.org/zap" 31 | ) 32 | 33 | type idempotentState int 34 | 35 | const ( 36 | notDetermined idempotentState = iota 37 | notIdempotent 38 | isIdempotent 39 | ) 40 | 41 | type request struct { 42 | client *client 43 | session *proxycore.Session 44 | state idempotentState 45 | keyspace string 46 | msg message.Message 47 | done bool 48 | retryCount int 49 | host *proxycore.Host 50 | stream int16 51 | version primitive.ProtocolVersion 52 | qp proxycore.QueryPlan 53 | frm interface{} 54 | isSelect bool // Only used for prepared statements currently 55 | mu sync.Mutex 56 | } 57 | 58 | func (r *request) Execute(next bool) { 59 | r.mu.Lock() 60 | r.executeInternal(next) 61 | r.mu.Unlock() 62 | } 63 | 64 | // lock before using 65 | func (r *request) executeInternal(next bool) { 66 | for !r.done { 67 | if next { 68 | r.host = r.qp.Next() 69 | } 70 | if r.host == nil { 71 | r.done = true 72 | r.send(&message.ServerError{ErrorMessage: "Proxy exhausted query plan and there are no more hosts available to try"}) 73 | } else { 74 | err := r.session.Send(r.host, r) 75 | if err == nil { 76 | break 77 | } else { 78 | r.client.proxy.logger.Debug("failed to send request to host", zap.Stringer("host", r.host), zap.Error(err)) 79 | } 80 | } 81 | } 82 | } 83 | 84 | func (r *request) send(msg message.Message) { 85 | _ = r.client.conn.Write(proxycore.SenderFunc(func(writer io.Writer) error { 86 | return r.client.codec.EncodeFrame(frame.NewFrame(r.version, r.stream, msg), writer) 87 | })) 88 | } 89 | 90 | func (r *request) sendRaw(raw *frame.RawFrame) { 91 | 
raw.Header.StreamId = r.stream 92 | _ = r.client.conn.Write(proxycore.SenderFunc(func(writer io.Writer) error { 93 | return r.client.codec.EncodeRawFrame(raw, writer) 94 | })) 95 | } 96 | 97 | func (r *request) Frame() interface{} { 98 | return r.frm 99 | } 100 | func (r *request) IsPrepareRequest() bool { 101 | _, isPrepare := r.msg.(*message.Prepare) 102 | return isPrepare 103 | } 104 | 105 | func (r *request) checkIdempotent() bool { 106 | if notDetermined == r.state { 107 | idempotent := false 108 | var err error 109 | if r.msg != nil { 110 | switch msg := r.msg.(type) { 111 | case *codecs.PartialQuery: 112 | idempotent, err = parser.IsQueryIdempotent(msg.Query) 113 | case *codecs.PartialExecute: 114 | idempotent = r.client.proxy.isIdempotent(msg.QueryId) 115 | case *codecs.PartialBatch: 116 | idempotent, err = r.isBatchIdempotent(msg) 117 | default: 118 | r.client.proxy.logger.Error("invalid message type encountered when checking for idempotence", 119 | zap.Stringer("type", reflect.TypeOf(msg))) 120 | } 121 | } 122 | if err != nil { 123 | r.client.proxy.logger.Error("error parsing query for idempotence", 124 | zap.Error(err), 125 | zap.Stringer("type", reflect.TypeOf(r.msg))) 126 | } 127 | if idempotent { 128 | r.state = isIdempotent 129 | } else { 130 | r.state = notIdempotent 131 | } 132 | } 133 | return isIdempotent == r.state 134 | } 135 | 136 | func (r *request) OnClose(_ error) { 137 | r.mu.Lock() 138 | defer r.mu.Unlock() 139 | 140 | if r.checkIdempotent() { 141 | r.executeInternal(true) 142 | } else { 143 | if !r.done { 144 | r.done = true 145 | r.send(&message.ServerError{ErrorMessage: "Proxy is unable to retry non-idempotent query after connection to backend cluster closed"}) 146 | } 147 | } 148 | } 149 | 150 | func (r *request) OnResult(raw *frame.RawFrame) { 151 | r.mu.Lock() 152 | defer r.mu.Unlock() 153 | if !r.done { 154 | if raw.Header.OpCode != primitive.OpCodeError || 155 | !r.handleErrorResult(raw) { // If the error result is retried then we 
don't send back this response 156 | r.client.maybeStorePreparedMetadata(raw, r.isSelect, r.msg) 157 | r.done = true 158 | r.sendRaw(raw) 159 | } 160 | } 161 | } 162 | 163 | func (r *request) handleErrorResult(raw *frame.RawFrame) (retried bool) { 164 | retried = false 165 | logger := r.client.proxy.logger 166 | decision := ReturnError 167 | 168 | frm, err := r.client.codec.ConvertFromRawFrame(raw) 169 | if err != nil { 170 | logger.Error("unable to decode error frame for retry decision", zap.Error(err)) 171 | } else { 172 | errMsg := frm.Body.Message.(message.Error) 173 | 174 | logger.Debug("received error response", 175 | zap.Stringer("host", r.host), 176 | zap.Stringer("errorCode", errMsg.GetErrorCode()), 177 | zap.String("error", errMsg.GetErrorMessage()), 178 | ) 179 | switch msg := frm.Body.Message.(type) { 180 | case *message.ReadTimeout: 181 | decision = r.client.proxy.config.RetryPolicy.OnReadTimeout(msg, r.retryCount) 182 | if decision != ReturnError { 183 | logger.Debug("retrying read timeout", 184 | zap.Stringer("decision", decision), 185 | zap.Stringer("response", msg), 186 | zap.Int("retryCount", r.retryCount), 187 | ) 188 | } 189 | case *message.WriteTimeout: 190 | if r.checkIdempotent() { 191 | decision = r.client.proxy.config.RetryPolicy.OnWriteTimeout(msg, r.retryCount) 192 | if decision != ReturnError { 193 | logger.Debug("retrying write timeout", 194 | zap.Stringer("decision", decision), 195 | zap.Stringer("response", msg), 196 | zap.Int("retryCount", r.retryCount), 197 | ) 198 | } 199 | } 200 | case *message.Unavailable: 201 | decision = r.client.proxy.config.RetryPolicy.OnUnavailable(msg, r.retryCount) 202 | if decision != ReturnError { 203 | logger.Debug("retrying on unavailable error", 204 | zap.Stringer("decision", decision), 205 | zap.Stringer("response", msg), 206 | zap.Int("retryCount", r.retryCount), 207 | ) 208 | } 209 | case *message.IsBootstrapping: 210 | decision = RetryNext 211 | logger.Debug("retrying on bootstrapping error", 212 | 
zap.Stringer("decision", decision), 213 | zap.Int("retryCount", r.retryCount), 214 | ) 215 | case *message.ServerError, *message.Overloaded, *message.TruncateError, 216 | *message.ReadFailure, *message.WriteFailure: 217 | if r.checkIdempotent() { 218 | decision = r.client.proxy.config.RetryPolicy.OnErrorResponse(errMsg, r.retryCount) 219 | if decision != ReturnError { 220 | logger.Debug("retrying on error response", 221 | zap.Stringer("decision", decision), 222 | zap.Int("retryCount", r.retryCount), 223 | ) 224 | } 225 | } 226 | default: 227 | // Do nothing, return the error 228 | } 229 | 230 | switch decision { 231 | case RetryNext: 232 | r.retryCount++ 233 | r.executeInternal(true) 234 | retried = true 235 | case RetrySame: 236 | r.retryCount++ 237 | r.executeInternal(false) 238 | retried = true 239 | default: 240 | // Do nothing, return the error 241 | } 242 | } 243 | 244 | return retried 245 | } 246 | 247 | func (r *request) isBatchIdempotent(batch *codecs.PartialBatch) (idempotent bool, err error) { 248 | for _, query := range batch.Queries { 249 | switch q := query.QueryOrId.(type) { 250 | case string: 251 | if idempotent, err = parser.IsQueryIdempotent(q); !idempotent { 252 | return idempotent, err 253 | } 254 | case []byte: 255 | idempotent = r.client.proxy.isIdempotent(q) 256 | if !idempotent { 257 | return false, nil 258 | } 259 | default: 260 | return false, errors.New("unhandled query type in batch") 261 | } 262 | } 263 | return true, nil 264 | } 265 | -------------------------------------------------------------------------------- /proxy/retrypolicy.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxy 16 | 17 | import ( 18 | "github.com/datastax/go-cassandra-native-protocol/message" 19 | "github.com/datastax/go-cassandra-native-protocol/primitive" 20 | ) 21 | 22 | // RetryDecision is a type used for deciding what to do when a request has failed. 23 | type RetryDecision int 24 | 25 | func (r RetryDecision) String() string { 26 | switch r { 27 | case RetrySame: 28 | return "retry same node" 29 | case RetryNext: 30 | return "retry next node" 31 | case ReturnError: 32 | return "returning error" 33 | } 34 | return "unknown" 35 | } 36 | 37 | const ( 38 | // RetrySame should be returned when a request should be retried on the same host. 39 | RetrySame RetryDecision = iota 40 | // RetryNext should be returned when a request should be retried on the next host according to the request's query 41 | // plan. 42 | RetryNext 43 | // ReturnError should be returned when a request's original error should be forwarded along to the client. 44 | ReturnError 45 | ) 46 | 47 | // RetryPolicy is an interface for defining retry behavior when a server-side error occurs. 48 | type RetryPolicy interface { 49 | // OnReadTimeout handles the retry decision for a server-side read timeout error (Read_timeout = 0x1200). 50 | // This occurs when a replica read request times out during a read query. 51 | OnReadTimeout(msg *message.ReadTimeout, retryCount int) RetryDecision 52 | 53 | // OnWriteTimeout handles the retry decision for a server-side write timeout error (Write_timeout = 0x1100). 
54 | // This occurs when a replica write request times out during a write query. 55 | OnWriteTimeout(msg *message.WriteTimeout, retryCount int) RetryDecision 56 | 57 | // OnUnavailable handles the retry decision for a server-side unavailable exception (Unavailable = 0x1000). 58 | // This occurs when a coordinator determines that there are not enough replicas to handle a query at the requested 59 | // consistency level. 60 | OnUnavailable(msg *message.Unavailable, retryCount int) RetryDecision 61 | 62 | // OnErrorResponse handles the retry decision for other potentially recoverable errors. 63 | // This can be called for the following error types: server error (ServerError = 0x0000), 64 | // overloaded (Overloaded = 0x1001), truncate error (Truncate_error = 0x1003), read failure (Read_failure = 0x1300), 65 | // and write failure (Write_failure = 0x1500). 66 | OnErrorResponse(msg message.Error, retryCount int) RetryDecision 67 | } 68 | 69 | type defaultRetryPolicy struct{} 70 | 71 | var defaultRetryPolicyInstance defaultRetryPolicy 72 | 73 | // NewDefaultRetryPolicy creates a new default retry policy. 74 | // The default retry policy takes a conservative approach to retrying requests. In most cases it retries only once in 75 | // cases where a retry is likely to succeed. 76 | func NewDefaultRetryPolicy() RetryPolicy { 77 | return &defaultRetryPolicyInstance 78 | } 79 | 80 | // OnReadTimeout retries in the case where there were enough replicas to satisfy the request, but one of the replicas 81 | // didn't respond with data and timed out. It's likely that a single retry to the same coordinator will succeed because 82 | // it will have recognized the replica as dead before the retry is attempted. 83 | // 84 | // In all other cases it will forward the original error to the client. 
85 | func (d defaultRetryPolicy) OnReadTimeout(msg *message.ReadTimeout, retryCount int) RetryDecision { 86 | if retryCount == 0 && msg.Received >= msg.BlockFor && !msg.DataPresent { 87 | return RetrySame 88 | } else { 89 | return ReturnError 90 | } 91 | } 92 | 93 | // OnWriteTimeout retries in the case where a coordinator failed to write its batch log to a set of datacenter local 94 | // nodes. It's likely that a single retry to the same coordinator will succeed because it will have recognized the 95 | // dead nodes and use a different set of nodes. 96 | // 97 | // In all other cases it will forward the original error to the client. 98 | func (d defaultRetryPolicy) OnWriteTimeout(msg *message.WriteTimeout, retryCount int) RetryDecision { 99 | if retryCount == 0 && msg.WriteType == primitive.WriteTypeBatchLog { 100 | return RetrySame 101 | } else { 102 | return ReturnError 103 | } 104 | } 105 | 106 | // OnUnavailable retries, once, on the next coordinator in the query plan. This is to handle the case where a 107 | // coordinator is failing because it was partitioned from a set of its replicas. 108 | func (d defaultRetryPolicy) OnUnavailable(_ *message.Unavailable, retryCount int) RetryDecision { 109 | if retryCount == 0 { 110 | return RetryNext 111 | } else { 112 | return ReturnError 113 | } 114 | } 115 | 116 | // OnErrorResponse retries on the next coordinator for all error types except for read and write failures. 117 | func (d defaultRetryPolicy) OnErrorResponse(msg message.Error, retryCount int) RetryDecision { 118 | code := msg.GetErrorCode() 119 | if code == primitive.ErrorCodeReadFailure || code == primitive.ErrorCodeWriteFailure { 120 | return ReturnError 121 | } else { 122 | return RetryNext 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /proxy/retrypolicy_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxy 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/datastax/go-cassandra-native-protocol/message" 21 | "github.com/datastax/go-cassandra-native-protocol/primitive" 22 | "github.com/stretchr/testify/assert" 23 | ) 24 | 25 | func TestDefaultRetryPolicy_OnUnavailable(t *testing.T) { 26 | var tests = []struct { 27 | msg *message.Unavailable 28 | decision RetryDecision 29 | retryCount int 30 | }{ 31 | {&message.Unavailable{Consistency: 0, Required: 0, Alive: 0}, RetryNext, 0}, // Never retried 32 | {&message.Unavailable{Consistency: 0, Required: 0, Alive: 0}, ReturnError, 1}, // Already retried once 33 | } 34 | 35 | policy := NewDefaultRetryPolicy() 36 | for _, tt := range tests { 37 | decision := policy.OnUnavailable(tt.msg, tt.retryCount) 38 | assert.Equal(t, tt.decision, decision) 39 | } 40 | } 41 | 42 | func TestDefaultRetryPolicy_OnReadTimeout(t *testing.T) { 43 | var tests = []struct { 44 | msg *message.ReadTimeout 45 | decision RetryDecision 46 | retryCount int 47 | }{ 48 | {&message.ReadTimeout{Consistency: 0, Received: 2, BlockFor: 2, DataPresent: false}, RetrySame, 0}, // Enough received with no data 49 | {&message.ReadTimeout{Consistency: 0, Received: 3, BlockFor: 2, DataPresent: false}, ReturnError, 1}, // Already retried once 50 | {&message.ReadTimeout{Consistency: 0, Received: 2, BlockFor: 2, DataPresent: true}, ReturnError, 0}, // 
Data was present 51 | } 52 | 53 | policy := NewDefaultRetryPolicy() 54 | for _, tt := range tests { 55 | decision := policy.OnReadTimeout(tt.msg, tt.retryCount) 56 | assert.Equal(t, tt.decision, decision) 57 | } 58 | } 59 | 60 | func TestDefaultRetryPolicy_OnWriteTimeout(t *testing.T) { 61 | var tests = []struct { 62 | msg *message.WriteTimeout 63 | decision RetryDecision 64 | retryCount int 65 | }{ 66 | {&message.WriteTimeout{Consistency: 0, Received: 0, BlockFor: 0, WriteType: primitive.WriteTypeBatchLog, Contentions: 0}, RetrySame, 0}, // Logged batch 67 | {&message.WriteTimeout{Consistency: 0, Received: 0, BlockFor: 0, WriteType: primitive.WriteTypeBatchLog, Contentions: 0}, ReturnError, 1}, // Logged batch, already retried once 68 | {&message.WriteTimeout{Consistency: 0, Received: 0, BlockFor: 0, WriteType: primitive.WriteTypeSimple, Contentions: 0}, ReturnError, 0}, // Not a logged batch 69 | } 70 | 71 | policy := NewDefaultRetryPolicy() 72 | for _, tt := range tests { 73 | decision := policy.OnWriteTimeout(tt.msg, tt.retryCount) 74 | assert.Equal(t, tt.decision, decision) 75 | } 76 | } 77 | 78 | func TestDefaultRetryPolicy_OnErrorResponse(t *testing.T) { 79 | var tests = []struct { 80 | msg message.Error 81 | decision RetryDecision 82 | retryCount int 83 | }{ 84 | {&message.WriteFailure{}, ReturnError, 0}, // Write failure 85 | {&message.ReadFailure{}, ReturnError, 0}, // Read failure 86 | {&message.TruncateError{}, RetryNext, 0}, // Truncate failure 87 | {&message.ServerError{}, RetryNext, 0}, // Server failure 88 | {&message.Overloaded{}, RetryNext, 0}, // Overloaded failure 89 | } 90 | 91 | policy := NewDefaultRetryPolicy() 92 | for _, tt := range tests { 93 | decision := policy.OnErrorResponse(tt.msg, tt.retryCount) 94 | assert.Equal(t, tt.decision, decision) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /proxycore/auth.go: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "bytes" 19 | "fmt" 20 | 21 | "go.uber.org/zap" 22 | ) 23 | 24 | type Authenticator interface { 25 | InitialResponse(authenticator string, c *ClientConn) ([]byte, error) 26 | EvaluateChallenge(token []byte) ([]byte, error) 27 | Success(token []byte) error 28 | } 29 | 30 | type passwordAuth struct { 31 | authId string 32 | username string 33 | password string 34 | } 35 | 36 | const dseAuthenticator = "com.datastax.bdp.cassandra.auth.DseAuthenticator" 37 | const passwordAuthenticator = "org.apache.cassandra.auth.PasswordAuthenticator" 38 | const astraAuthenticator = "org.apache.cassandra.auth.AstraAuthenticator" 39 | 40 | func (d *passwordAuth) InitialResponse(authenticator string, c *ClientConn) ([]byte, error) { 41 | if authenticator == dseAuthenticator { 42 | return []byte("PLAIN"), nil 43 | } 44 | // We'll return a SASL response but if we're seeing an authenticator we're unfamiliar with at least log 45 | // that information here 46 | if (authenticator != passwordAuthenticator) && (authenticator != astraAuthenticator) { 47 | c.logger.Info("observed unknown authenticator, treating as SASL", 48 | zap.String("authenticator", authenticator)) 49 | } 50 | return d.makeToken(), nil 51 | } 52 | 53 | func (d 
*passwordAuth) EvaluateChallenge(token []byte) ([]byte, error) { 54 | if token == nil || bytes.Compare(token, []byte("PLAIN-START")) != 0 { 55 | return nil, fmt.Errorf("incorrect SASL challenge from server, expecting PLAIN-START, got: %v", string(token)) 56 | } 57 | return d.makeToken(), nil 58 | } 59 | 60 | func (d *passwordAuth) makeToken() []byte { 61 | token := bytes.NewBuffer(make([]byte, 0, len(d.authId)+len(d.username)+len(d.password)+2)) 62 | token.WriteString(d.authId) 63 | token.WriteByte(0) 64 | token.WriteString(d.username) 65 | token.WriteByte(0) 66 | token.WriteString(d.password) 67 | return token.Bytes() 68 | } 69 | 70 | func (d *passwordAuth) Success(_ []byte) error { 71 | return nil 72 | } 73 | 74 | func NewPasswordAuth(username string, password string) Authenticator { 75 | return &passwordAuth{ 76 | authId: "", 77 | username: username, 78 | password: password, 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /proxycore/cluster_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package proxycore 16 | 17 | import ( 18 | "context" 19 | "net" 20 | "testing" 21 | "time" 22 | 23 | "github.com/datastax/go-cassandra-native-protocol/primitive" 24 | "github.com/stretchr/testify/assert" 25 | "github.com/stretchr/testify/require" 26 | "go.uber.org/zap" 27 | ) 28 | 29 | func TestConnectCluster(t *testing.T) { 30 | logger, _ := zap.NewDevelopment() 31 | 32 | ctx, cancel := context.WithCancel(context.Background()) 33 | defer cancel() 34 | 35 | c := NewMockCluster(net.ParseIP("127.0.0.0"), 9042) 36 | 37 | err := c.Add(ctx, 1) 38 | require.NoError(t, err) 39 | 40 | err = c.Add(ctx, 2) 41 | require.NoError(t, err) 42 | 43 | err = c.Add(ctx, 3) 44 | require.NoError(t, err) 45 | 46 | cluster, err := ConnectCluster(ctx, ClusterConfig{ 47 | Version: primitive.ProtocolVersion4, 48 | Resolver: NewResolver("127.0.0.1:9042"), 49 | ReconnectPolicy: NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second), 50 | ConnectTimeout: 10 * time.Second, 51 | Logger: logger, 52 | HeartBeatInterval: 30 * time.Second, 53 | IdleTimeout: 60 * time.Second, 54 | }) 55 | require.NoError(t, err) 56 | 57 | events := make(chan interface{}) 58 | 59 | err = cluster.Listen(ClusterListenerFunc(func(event Event) { 60 | events <- event 61 | })) 62 | require.NoError(t, err) 63 | 64 | wait := func() interface{} { 65 | timer := time.NewTimer(2 * time.Second) 66 | select { 67 | case <-timer.C: 68 | require.Fail(t, "timed out waiting for event") 69 | case event := <-events: 70 | return event 71 | } 72 | require.Fail(t, "event expected") 73 | return nil 74 | } 75 | 76 | event := wait() 77 | require.IsType(t, event, &BootstrapEvent{Hosts: nil}) 78 | 79 | c.Stop(1) 80 | event = wait() 81 | assert.Equal(t, event, &ReconnectEvent{&defaultEndpoint{addr: "127.0.0.2:9042"}}) 82 | 83 | c.Stop(2) 84 | event = wait() 85 | assert.Equal(t, event, &ReconnectEvent{&defaultEndpoint{addr: "127.0.0.3:9042"}}) 86 | 87 | err = c.Start(ctx, 1) 88 | require.NoError(t, err) 89 | 90 | c.Stop(3) 91 | 
event = wait() 92 | assert.Equal(t, event, &ReconnectEvent{&defaultEndpoint{addr: "127.0.0.1:9042"}}) 93 | } 94 | -------------------------------------------------------------------------------- /proxycore/conn.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "bufio" 19 | "context" 20 | "crypto/tls" 21 | "errors" 22 | "io" 23 | "net" 24 | "sync" 25 | ) 26 | 27 | var ( 28 | Closed = errors.New("connection closed") 29 | AlreadyClosed = errors.New("connection already closed") 30 | ) 31 | 32 | const ( 33 | MaxMessages = 1024 34 | MaxCoalesceSize = 16 * 1024 // TODO: What's a good value for this? 
35 | ) 36 | 37 | type Conn struct { 38 | conn net.Conn 39 | closed chan struct{} 40 | messages chan Sender 41 | err error 42 | recv Receiver 43 | writer *bufio.Writer 44 | reader *bufio.Reader 45 | mu *sync.Mutex 46 | } 47 | 48 | type Receiver interface { 49 | Receive(reader io.Reader) error 50 | Closing(err error) 51 | } 52 | 53 | type Sender interface { 54 | Send(writer io.Writer) error 55 | } 56 | 57 | type SenderFunc func(writer io.Writer) error 58 | 59 | func (s SenderFunc) Send(writer io.Writer) error { 60 | return s(writer) 61 | } 62 | 63 | // Connect creates a new connection to a server specified by the endpoint using TLS if specified 64 | func Connect(ctx context.Context, endpoint Endpoint, recv Receiver) (c *Conn, err error) { 65 | var dialer net.Dialer 66 | addr, err := LookupEndpoint(endpoint) 67 | if err != nil { 68 | return nil, err 69 | } 70 | conn, err := dialer.DialContext(ctx, "tcp", addr) 71 | if err != nil { 72 | return nil, err 73 | } 74 | 75 | defer func() { 76 | if err != nil && conn != nil { 77 | _ = conn.Close() 78 | } 79 | }() 80 | 81 | if endpoint.TLSConfig() != nil { 82 | tlsConn := tls.Client(conn, endpoint.TLSConfig()) 83 | if err = tlsConn.Handshake(); err != nil { 84 | return nil, err 85 | } 86 | conn = tlsConn 87 | } 88 | 89 | c = NewConn(conn, recv) 90 | c.Start() 91 | return c, nil 92 | } 93 | 94 | func NewConn(conn net.Conn, recv Receiver) *Conn { 95 | return &Conn{ 96 | conn: conn, 97 | recv: recv, 98 | writer: bufio.NewWriterSize(conn, MaxCoalesceSize), 99 | reader: bufio.NewReader(conn), 100 | closed: make(chan struct{}), 101 | messages: make(chan Sender, MaxMessages), 102 | mu: &sync.Mutex{}, 103 | } 104 | } 105 | 106 | func (c *Conn) Start() { 107 | go c.read() 108 | go c.write() 109 | } 110 | 111 | func (c *Conn) read() { 112 | done := false 113 | for !done { 114 | done = c.checkErr(c.recv.Receive(c.reader)) 115 | } 116 | c.recv.Closing(c.Err()) 117 | } 118 | 119 | func (c *Conn) write() { 120 | done := false 121 | 122 | 
for !done { 123 | select { 124 | case sender := <-c.messages: 125 | done = c.checkErr(sender.Send(c.writer)) 126 | coalescing := true 127 | for coalescing && !done { 128 | select { 129 | case sender, coalescing = <-c.messages: 130 | done = c.checkErr(sender.Send(c.writer)) 131 | case <-c.closed: 132 | done = true 133 | default: 134 | coalescing = false 135 | } 136 | } 137 | case <-c.closed: 138 | done = true 139 | } 140 | 141 | if !done { // Check to avoid resetting `done` to false 142 | err := c.writer.Flush() 143 | done = c.checkErr(err) 144 | } 145 | } 146 | } 147 | 148 | func (c *Conn) WriteBytes(b []byte) error { 149 | return c.Write(SenderFunc(func(writer io.Writer) error { 150 | _, err := writer.Write(b) 151 | return err 152 | })) 153 | } 154 | 155 | func (c *Conn) Write(sender Sender) error { 156 | select { 157 | case c.messages <- sender: 158 | return nil 159 | case <-c.closed: 160 | return c.Err() 161 | } 162 | } 163 | 164 | func (c *Conn) checkErr(err error) bool { 165 | if err != nil { 166 | c.mu.Lock() 167 | if c.err == nil { 168 | c.err = err 169 | _ = c.conn.Close() 170 | close(c.closed) 171 | } 172 | c.mu.Unlock() 173 | return true 174 | } 175 | return false 176 | } 177 | 178 | func (c *Conn) Close() error { 179 | c.mu.Lock() 180 | defer c.mu.Unlock() 181 | if c.err != nil { 182 | return AlreadyClosed 183 | } 184 | close(c.closed) 185 | c.err = Closed 186 | return c.conn.Close() 187 | } 188 | 189 | func (c *Conn) Err() error { 190 | c.mu.Lock() 191 | err := c.err 192 | c.mu.Unlock() 193 | return err 194 | } 195 | 196 | func (c *Conn) IsClosed() chan struct{} { 197 | return c.closed 198 | } 199 | 200 | func (c *Conn) LocalAddr() net.Addr { 201 | return c.conn.LocalAddr() 202 | } 203 | 204 | func (c *Conn) RemoteAddr() net.Addr { 205 | return c.conn.RemoteAddr() 206 | } 207 | -------------------------------------------------------------------------------- /proxycore/conn_test.go: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "bytes" 19 | "context" 20 | "io" 21 | "math/rand" 22 | "net" 23 | "testing" 24 | "time" 25 | 26 | "github.com/stretchr/testify/assert" 27 | "github.com/stretchr/testify/require" 28 | ) 29 | 30 | func TestConnect(t *testing.T) { 31 | ctx := context.Background() 32 | 33 | listener, err := net.Listen("tcp", "127.0.0.1:8123") 34 | defer func(listener net.Listener) { 35 | _ = listener.Close() 36 | }(listener) 37 | require.NoError(t, err, "failed to listen") 38 | 39 | clientData := randomData(64 * 1024) 40 | serverData := randomData(64 * 1024) 41 | 42 | serverRecv := newTestRecv(clientData) 43 | servClosed := make(chan struct{}) 44 | 45 | go func() { 46 | c, err := listener.Accept() 47 | require.NoError(t, err, "failed to accept client connection") 48 | conn := NewConn(c, serverRecv) 49 | conn.Start() 50 | err = conn.WriteBytes(serverData) 51 | require.NoError(t, err, "failed to write bytes to client") 52 | select { 53 | case <-conn.IsClosed(): 54 | close(servClosed) 55 | } 56 | }() 57 | 58 | clientRecv := newTestRecv(serverData) 59 | clientConn, err := Connect(ctx, NewEndpoint("127.0.0.1:8123"), clientRecv) 60 | require.NoError(t, err, "failed to connect") 61 | 62 | err = clientConn.WriteBytes(clientData) 63 | 
require.NoError(t, err, "failed to write bytes to server") 64 | 65 | timer := time.NewTimer(2 * time.Second) 66 | 67 | wait := func(waitFor chan struct{}, msg string) { 68 | select { 69 | case <-waitFor: 70 | case <-timer.C: 71 | require.Fail(t, msg) 72 | } 73 | } 74 | 75 | wait(clientRecv.received, "timed out waiting to receive data from the server") 76 | wait(serverRecv.received, "timed out waiting to receive data from the client") 77 | 78 | _ = clientConn.Close() 79 | 80 | wait(clientConn.IsClosed(), "timed out waiting for client to close") 81 | wait(servClosed, "timed out waiting for server to close") 82 | 83 | wait(clientRecv.closing, "client closing method never called") 84 | wait(serverRecv.closing, "server closing method never called") 85 | } 86 | 87 | func TestConnect_Failures(t *testing.T) { 88 | var tests = []struct { 89 | endpoint Endpoint 90 | err string 91 | }{ 92 | {endpoint: NewEndpoint("127.0.0.1:8333"), err: "connection refused"}, 93 | {endpoint: &testEndpoint{addr: "127.0.0.1"}, err: "missing port in address"}, 94 | } 95 | ctx := context.Background() 96 | for _, tt := range tests { 97 | _, err := Connect(ctx, tt.endpoint, &testRecv{}) 98 | if assert.Error(t, err) { 99 | assert.Contains(t, err.Error(), tt.err) 100 | } 101 | } 102 | } 103 | 104 | type testRecv struct { 105 | expected []byte 106 | buf bytes.Buffer 107 | closing chan struct{} 108 | received chan struct{} 109 | } 110 | 111 | func newTestRecv(expected []byte) *testRecv { 112 | return &testRecv{ 113 | expected: expected, 114 | closing: make(chan struct{}), 115 | received: make(chan struct{}), 116 | } 117 | } 118 | 119 | func (t *testRecv) Receive(reader io.Reader) error { 120 | var buf [1024]byte 121 | n, err := reader.Read(buf[:]) 122 | if err != nil { 123 | return err 124 | } 125 | t.buf.Write(buf[:n]) 126 | if bytes.Equal(t.buf.Bytes(), t.expected) { 127 | close(t.received) 128 | } 129 | return nil 130 | } 131 | 132 | func (t *testRecv) Closing(_ error) { 133 | close(t.closing) 134 | 
} 135 | 136 | func randomData(n int) []byte { 137 | data := make([]byte, n) 138 | for i := 0; i < n; i++ { 139 | data[i] = 'a' + byte(rand.Intn(26)) 140 | } 141 | return data 142 | } 143 | -------------------------------------------------------------------------------- /proxycore/connpool.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "context" 19 | "errors" 20 | "fmt" 21 | "math" 22 | "sync" 23 | "time" 24 | 25 | "github.com/datastax/go-cassandra-native-protocol/primitive" 26 | "go.uber.org/zap" 27 | ) 28 | 29 | type connPoolConfig struct { 30 | Endpoint 31 | SessionConfig 32 | } 33 | 34 | type connPool struct { 35 | ctx context.Context 36 | config connPoolConfig 37 | logger *zap.Logger 38 | preparedCache PreparedCache 39 | cancel context.CancelFunc 40 | conns []*ClientConn 41 | connsMu *sync.RWMutex 42 | } 43 | 44 | // connectPool establishes a pool of connections to a given endpoint within a downstream cluster. These connection pools will 45 | // be used to proxy requests from the client to the cluster. 
46 | func connectPool(ctx context.Context, config connPoolConfig) (*connPool, error) { 47 | ctx, cancel := context.WithCancel(ctx) 48 | 49 | pool := &connPool{ 50 | ctx: ctx, 51 | config: config, 52 | logger: GetOrCreateNopLogger(config.Logger), 53 | preparedCache: config.PreparedCache, 54 | cancel: cancel, 55 | conns: make([]*ClientConn, config.NumConns), 56 | connsMu: &sync.RWMutex{}, 57 | } 58 | 59 | errs := make([]error, config.NumConns) 60 | wg := sync.WaitGroup{} 61 | wg.Add(config.NumConns) 62 | 63 | for i := 0; i < config.NumConns; i++ { 64 | go func(idx int) { 65 | pool.conns[idx], errs[idx] = pool.connect() 66 | wg.Done() 67 | }(i) 68 | } 69 | 70 | wg.Wait() 71 | 72 | for _, err := range errs { 73 | if err != nil { 74 | pool.logger.Error("unable to connect pool", zap.Stringer("endpoint", config.Endpoint), zap.Error(err)) 75 | if isCriticalErr(err) { 76 | return nil, err 77 | } 78 | } 79 | } 80 | 81 | for i := 0; i < config.NumConns; i++ { 82 | go pool.stayConnected(i) 83 | } 84 | 85 | return pool, nil 86 | } 87 | 88 | func connectPoolNoFail(ctx context.Context, config connPoolConfig) *connPool { 89 | ctx, cancel := context.WithCancel(ctx) 90 | 91 | pool := &connPool{ 92 | ctx: ctx, 93 | config: config, 94 | logger: GetOrCreateNopLogger(config.Logger), 95 | cancel: cancel, 96 | conns: make([]*ClientConn, config.NumConns), 97 | connsMu: &sync.RWMutex{}, 98 | } 99 | 100 | for i := 0; i < config.NumConns; i++ { 101 | go pool.stayConnected(i) 102 | } 103 | 104 | return pool 105 | } 106 | 107 | func (p *connPool) leastBusyConn() *ClientConn { 108 | p.connsMu.RLock() 109 | defer p.connsMu.RUnlock() 110 | count := len(p.conns) 111 | if count == 0 { 112 | return nil 113 | } else if count == 1 { 114 | return p.conns[0] 115 | } else { 116 | idx := 0 117 | min := int32(math.MaxInt32) 118 | for i, conn := range p.conns { 119 | if conn != nil { 120 | inflight := conn.Inflight() 121 | if inflight < min { 122 | idx = i 123 | min = inflight 124 | } 125 | } 126 | } 127 | 
return p.conns[idx] 128 | } 129 | } 130 | 131 | func (p *connPool) connect() (conn *ClientConn, err error) { 132 | p.logger.Debug("creating pooled connection", 133 | zap.Stringer("endpoint", p.config.Endpoint), 134 | zap.Stringer("connect timeout", p.config.ConnectTimeout)) 135 | ctx, cancel := context.WithTimeout(p.ctx, p.config.ConnectTimeout) 136 | defer cancel() 137 | conn, err = ConnectClient(ctx, p.config.Endpoint, ClientConnConfig{ 138 | PreparedCache: p.preparedCache, 139 | Logger: p.logger}) 140 | if err != nil { 141 | return nil, err 142 | } 143 | 144 | defer func() { 145 | if err != nil && conn != nil { 146 | _ = conn.Close() 147 | } 148 | }() 149 | 150 | var startupKeysAndValues []string 151 | if p.config.Compression != "" { 152 | startupKeysAndValues = []string{"COMPRESSION", p.config.Compression} 153 | } 154 | 155 | var version primitive.ProtocolVersion 156 | version, err = conn.Handshake(ctx, p.config.Version, p.config.Auth, startupKeysAndValues...) 157 | if err != nil { 158 | if errors.Is(err, context.DeadlineExceeded) { 159 | return nil, fmt.Errorf("handshake took longer than %s to complete", p.config.ConnectTimeout) 160 | } 161 | return nil, err 162 | } 163 | if version != p.config.Version { 164 | p.logger.Error("protocol version not support", zap.Stringer("wanted", p.config.Version), zap.Stringer("got", version)) 165 | return nil, ProtocolNotSupported 166 | } 167 | 168 | if len(p.config.Keyspace) != 0 { 169 | err = conn.SetKeyspace(ctx, p.config.Version, p.config.Keyspace) 170 | if err != nil { 171 | return nil, err 172 | } 173 | } 174 | 175 | go conn.Heartbeats(p.config.ConnectTimeout, p.config.Version, p.config.HeartBeatInterval, p.config.IdleTimeout, p.logger) 176 | return conn, nil 177 | } 178 | 179 | // stayConnected will attempt to reestablish a disconnected (`connection == nil`) connection within the pool. Reconnect attempts 180 | // will be made at intervals defined by the ReconnectPolicy. 
181 | func (p *connPool) stayConnected(idx int) { 182 | conn := p.conns[idx] 183 | 184 | connectTimer := time.NewTimer(0) 185 | reconnectPolicy := p.config.ReconnectPolicy.Clone() 186 | pendingConnect := true 187 | 188 | done := false 189 | for !done { 190 | if conn == nil { 191 | if !pendingConnect { 192 | delay := reconnectPolicy.NextDelay() 193 | p.logger.Info("pool connection attempting to reconnect after delay", 194 | zap.Stringer("endpoint", p.config.Endpoint), zap.Duration("delay", delay)) 195 | connectTimer = time.NewTimer(reconnectPolicy.NextDelay()) 196 | pendingConnect = true 197 | } else { 198 | select { 199 | case <-p.ctx.Done(): 200 | done = true 201 | case <-connectTimer.C: 202 | c, err := p.connect() 203 | if err != nil { 204 | p.logger.Error("pool connection failed to connect", 205 | zap.Stringer("endpoint", p.config.Endpoint), zap.Error(err)) 206 | } else { 207 | p.connsMu.Lock() 208 | conn, p.conns[idx] = c, c 209 | p.connsMu.Unlock() 210 | reconnectPolicy.Reset() 211 | } 212 | pendingConnect = false 213 | } 214 | } 215 | } else { 216 | select { 217 | case <-p.ctx.Done(): 218 | done = true 219 | _ = conn.Close() 220 | case <-conn.IsClosed(): 221 | p.logger.Info("pool connection closed", zap.Stringer("endpoint", p.config.Endpoint), zap.Error(conn.Err())) 222 | p.connsMu.Lock() 223 | conn, p.conns[idx] = nil, nil 224 | p.connsMu.Unlock() 225 | pendingConnect = false 226 | } 227 | } 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /proxycore/endpoint.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "context" 19 | "crypto/tls" 20 | "errors" 21 | "fmt" 22 | "math/rand" 23 | "net" 24 | "strconv" 25 | ) 26 | 27 | var IgnoreEndpoint = errors.New("ignore endpoint") 28 | 29 | type Endpoint interface { 30 | fmt.Stringer 31 | Addr() string 32 | IsResolved() bool 33 | TLSConfig() *tls.Config 34 | Key() string 35 | } 36 | 37 | type defaultEndpoint struct { 38 | addr string 39 | tlsConfig *tls.Config 40 | } 41 | 42 | func (e defaultEndpoint) String() string { 43 | return e.Key() 44 | } 45 | 46 | func (e defaultEndpoint) Key() string { 47 | return e.addr 48 | } 49 | 50 | func (e defaultEndpoint) IsResolved() bool { 51 | return true 52 | } 53 | 54 | func (e defaultEndpoint) Addr() string { 55 | return e.addr 56 | } 57 | 58 | func (e defaultEndpoint) TLSConfig() *tls.Config { 59 | return e.tlsConfig 60 | } 61 | 62 | type EndpointResolver interface { 63 | Resolve(ctx context.Context) ([]Endpoint, error) 64 | NewEndpoint(row Row) (Endpoint, error) 65 | } 66 | 67 | type defaultEndpointResolver struct { 68 | contactPoints []string 69 | defaultPort string 70 | } 71 | 72 | func NewEndpoint(addr string) Endpoint { 73 | return &defaultEndpoint{addr: addr} 74 | } 75 | 76 | func NewEndpointTLS(addr string, cfg *tls.Config) Endpoint { 77 | return &defaultEndpoint{addr, cfg} 78 | } 79 | 80 | func NewResolver(contactPoints ...string) EndpointResolver { 81 | return NewResolverWithDefaultPort(contactPoints, 9042) 82 | } 83 | 84 | func NewResolverWithDefaultPort(contactPoints []string, 
defaultPort int) EndpointResolver { 85 | return &defaultEndpointResolver{ 86 | contactPoints: contactPoints, 87 | defaultPort: strconv.Itoa(defaultPort), 88 | } 89 | } 90 | 91 | func (r *defaultEndpointResolver) Resolve(ctx context.Context) ([]Endpoint, error) { 92 | var endpoints []Endpoint 93 | var resolver net.Resolver 94 | for _, cp := range r.contactPoints { 95 | host, port, err := net.SplitHostPort(cp) 96 | if err != nil { 97 | host = cp 98 | } 99 | addrs, err := resolver.LookupHost(ctx, host) 100 | if err != nil { 101 | return nil, fmt.Errorf("unable to resolve contact point %s: %v", cp, err) 102 | } 103 | if len(port) == 0 { 104 | port = r.defaultPort 105 | } 106 | for _, addr := range addrs { 107 | endpoints = append(endpoints, &defaultEndpoint{ 108 | addr: net.JoinHostPort(addr, port), 109 | }) 110 | } 111 | } 112 | return endpoints, nil 113 | } 114 | 115 | func (r *defaultEndpointResolver) NewEndpoint(row Row) (Endpoint, error) { 116 | peer, err := row.ByName("peer") 117 | if err != nil && !errors.Is(err, ColumnNameNotFound) { 118 | return nil, err 119 | } 120 | rpcAddress, err := row.InetByName("rpc_address") 121 | if err != nil { 122 | return nil, fmt.Errorf("ignoring host because its `rpc_address` is not set or is invalid: %w", err) 123 | } 124 | 125 | addr := rpcAddress 126 | if addr.Equal(net.IPv4zero) || addr.Equal(net.IPv6zero) { 127 | var ok bool 128 | if addr, ok = peer.(net.IP); !ok { 129 | return nil, errors.New("ignoring host because its `peer` is not set or is invalid") 130 | } 131 | } 132 | 133 | return &defaultEndpoint{ 134 | addr: net.JoinHostPort(addr.String(), r.defaultPort), 135 | }, nil 136 | } 137 | 138 | func LookupEndpoint(endpoint Endpoint) (string, error) { 139 | if endpoint.IsResolved() { 140 | return endpoint.Addr(), nil 141 | } else { 142 | host, port, err := net.SplitHostPort(endpoint.Addr()) 143 | if err != nil { 144 | return "", err 145 | } 146 | addrs, err := net.LookupHost(host) 147 | if err != nil { 148 | return "", err 
149 | } 150 | addr := addrs[rand.Intn(len(addrs))] 151 | if len(port) > 0 { 152 | addr = net.JoinHostPort(addr, port) 153 | } 154 | return addr, nil 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /proxycore/endpoint_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package proxycore 16 | 17 | import ( 18 | "crypto/tls" 19 | "net" 20 | "testing" 21 | 22 | "github.com/datastax/cql-proxy/codecs" 23 | "github.com/datastax/go-cassandra-native-protocol/datatype" 24 | "github.com/datastax/go-cassandra-native-protocol/message" 25 | "github.com/datastax/go-cassandra-native-protocol/primitive" 26 | "github.com/stretchr/testify/assert" 27 | "github.com/stretchr/testify/require" 28 | ) 29 | /* TestLookupEndpoint checks that a host name is looked up (localhost may map to IPv4 or IPv6) and that an already-resolved endpoint's address is returned unchanged. */ 30 | func TestLookupEndpoint(t *testing.T) { 31 | addr, err := LookupEndpoint(&testEndpoint{addr: "localhost:9042"}) 32 | require.NoError(t, err, "unable to lookup endpoint") 33 | assert.True(t, addr == "127.0.0.1:9042" || addr == "[::1]:9042") 34 | 35 | addr, err = LookupEndpoint(&testEndpoint{addr: "127.0.0.1:9042", isResolved: true}) 36 | require.NoError(t, err, "unable to lookup endpoint") 37 | assert.Equal(t, "127.0.0.1:9042", addr) 38 | } 39 | /* TestLookupEndpoint_Invalid checks that malformed or unresolvable addresses fail; an empty err substring accepts any error text. */ 40 | func TestLookupEndpoint_Invalid(t *testing.T) { 41 | var tests = []struct { 42 | addr string 43 | err string 44 | }{ 45 | {"localhost", "missing port in address"}, 46 | {"test:1234", ""}, // Errors for DNS can vary per system 47 | } 48 | 49 | for _, tt := range tests { 50 | _, err := LookupEndpoint(&testEndpoint{addr: tt.addr}) 51 | if assert.Error(t, err) { 52 | assert.Contains(t, err.Error(), tt.err) 53 | } 54 | } 55 | } 56 | /* TestEndpoint_NewEndpoint builds an endpoint from a row that only has rpc_address and expects the key to contain it. */ 57 | func TestEndpoint_NewEndpoint(t *testing.T) { 58 | resolver := NewResolver("127.0.0.1") 59 | 60 | const rpcAddr = "127.0.0.2" 61 | 62 | rpcAddrBytes, _ := codecs.EncodeType(datatype.Inet, primitive.ProtocolVersion4, net.ParseIP(rpcAddr)) 63 | 64 | rs := NewResultSet(&message.RowsResult{ 65 | Metadata: &message.RowsMetadata{ 66 | ColumnCount: 1, 67 | Columns: []*message.ColumnMetadata{ 68 | { 69 | Keyspace: "system", 70 | Table: "peers", 71 | Name: "rpc_address", 72 | Index: 0, 73 | Type: datatype.Inet, 74 | }, 75 | }, 76 | }, 77 | Data: message.RowSet{ 78 | message.Row{rpcAddrBytes}, 79 | }, 80 | }, primitive.ProtocolVersion4) 81 | 82 | endpoint, err := 
resolver.NewEndpoint(rs.Row(0)) 83 | assert.NotNil(t, endpoint) 84 | assert.Nil(t, err) 85 | assert.Contains(t, endpoint.Key(), rpcAddr) 86 | } 87 | /* TestEndpoint_NewEndpointUnknownRPCAddress: with a wildcard rpc_address (0.0.0.0) the endpoint key is expected to use the peer address instead. */ 88 | func TestEndpoint_NewEndpointUnknownRPCAddress(t *testing.T) { 89 | resolver := NewResolver("127.0.0.1") 90 | 91 | const rpcAddr = "0.0.0.0" 92 | rpcAddrBytes, _ := codecs.EncodeType(datatype.Inet, primitive.ProtocolVersion4, net.ParseIP(rpcAddr)) 93 | 94 | const peer = "127.0.0.2" 95 | peerBytes, _ := codecs.EncodeType(datatype.Inet, primitive.ProtocolVersion4, net.ParseIP(peer)) 96 | 97 | rs := NewResultSet(&message.RowsResult{ 98 | Metadata: &message.RowsMetadata{ 99 | ColumnCount: 2, /* matches len(Columns) */ 100 | Columns: []*message.ColumnMetadata{ 101 | { 102 | Keyspace: "system", 103 | Table: "peers", 104 | Name: "peer", 105 | Index: 0, 106 | Type: datatype.Inet, 107 | }, 108 | { 109 | Keyspace: "system", 110 | Table: "peers", 111 | Name: "rpc_address", 112 | Index: 1, 113 | Type: datatype.Inet, 114 | }, 115 | }, 116 | }, 117 | Data: message.RowSet{ 118 | message.Row{peerBytes, rpcAddrBytes}, 119 | }, 120 | }, primitive.ProtocolVersion4) 121 | 122 | endpoint, err := resolver.NewEndpoint(rs.Row(0)) 123 | assert.NotNil(t, endpoint) 124 | assert.Nil(t, err) 125 | assert.Contains(t, endpoint.Key(), peer) 126 | } 127 | /* TestEndpoint_NewEndpointInvalidRPCAddress: a row with a null rpc_address must be rejected. */ 128 | func TestEndpoint_NewEndpointInvalidRPCAddress(t *testing.T) { 129 | resolver := NewResolver("127.0.0.1") 130 | 131 | const peer = "127.0.0.2" 132 | peerBytes, _ := codecs.EncodeType(datatype.Inet, primitive.ProtocolVersion4, net.ParseIP(peer)) 133 | 134 | rs := NewResultSet(&message.RowsResult{ 135 | Metadata: &message.RowsMetadata{ 136 | ColumnCount: 2, /* matches len(Columns) */ 137 | Columns: []*message.ColumnMetadata{ 138 | { 139 | Keyspace: "system", 140 | Table: "peers", 141 | Name: "peer", 142 | Index: 0, 143 | Type: datatype.Inet, 144 | }, 145 | { 146 | Keyspace: "system", 147 | Table: "peers", 148 | Name: "rpc_address", 149 | Index: 1, 150 | Type: datatype.Inet, 151 | }, 152 | }, 153 | }, 154 | Data: message.RowSet{ 155 | 
message.Row{peerBytes, nil}, // Null rpc_address 156 | }, 157 | }, primitive.ProtocolVersion4) 158 | 159 | endpoint, err := resolver.NewEndpoint(rs.Row(0)) 160 | assert.Nil(t, endpoint) 161 | assert.Error(t, err, "ignoring host because its `rpc_address` is not set or is invalid") /* NOTE(review): this string is only the failure message, it is not compared with err; use assert.ErrorContains if the error text is meant to be checked */ 162 | } 163 | /* TestEndpoint_NewEndpointInvalidPeer: a wildcard rpc_address combined with a null peer must be rejected. */ 164 | func TestEndpoint_NewEndpointInvalidPeer(t *testing.T) { 165 | resolver := NewResolver("127.0.0.1") 166 | 167 | const rpcAddr = "0.0.0.0" 168 | rpcAddrBytes, _ := codecs.EncodeType(datatype.Inet, primitive.ProtocolVersion4, net.ParseIP(rpcAddr)) 169 | 170 | rs := NewResultSet(&message.RowsResult{ 171 | Metadata: &message.RowsMetadata{ 172 | ColumnCount: 2, /* matches len(Columns) */ 173 | Columns: []*message.ColumnMetadata{ 174 | { 175 | Keyspace: "system", 176 | Table: "peers", 177 | Name: "peer", 178 | Index: 0, 179 | Type: datatype.Inet, 180 | }, 181 | { 182 | Keyspace: "system", 183 | Table: "peers", 184 | Name: "rpc_address", 185 | Index: 1, 186 | Type: datatype.Inet, 187 | }, 188 | }, 189 | }, 190 | Data: message.RowSet{ 191 | message.Row{nil, rpcAddrBytes}, // Null peer 192 | }, 193 | }, primitive.ProtocolVersion4) 194 | 195 | endpoint, err := resolver.NewEndpoint(rs.Row(0)) 196 | assert.Nil(t, endpoint) 197 | assert.Error(t, err, "ignoring host because its `peer` is not set or is invalid") /* NOTE(review): message argument only, not an assertion on err's text */ 198 | } 199 | /* testEndpoint is a minimal stand-in endpoint: String, Addr and Key all return addr, and TLSConfig is nil. */ 200 | type testEndpoint struct { 201 | addr string 202 | isResolved bool 203 | } 204 | 205 | func (t testEndpoint) String() string { 206 | return t.addr 207 | } 208 | 209 | func (t testEndpoint) Addr() string { 210 | return t.addr 211 | } 212 | 213 | func (t testEndpoint) IsResolved() bool { 214 | return t.isResolved 215 | } 216 | 217 | func (t testEndpoint) TLSConfig() *tls.Config { 218 | return nil 219 | } 220 | 221 | func (t testEndpoint) Key() string { 222 | return t.addr 223 | } 224 | -------------------------------------------------------------------------------- /proxycore/errors.go: -------------------------------------------------------------------------------- 1 | // Copyright 
(c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "errors" 19 | "fmt" 20 | "io" 21 | "strings" 22 | "syscall" 23 | 24 | "github.com/datastax/go-cassandra-native-protocol/message" 25 | ) 26 | /* Sentinel errors; they are plain errors.New values, so callers can match them with errors.Is. */ 27 | var ( 28 | StreamsExhausted = errors.New("streams exhausted") 29 | AuthExpected = errors.New("authentication required, but no authenticator provided") 30 | ProtocolNotSupported = errors.New("required protocol version is not supported") 31 | ) 32 | /* UnexpectedResponse reports that the Received response kind was not one of the Expected kinds. */ 33 | type UnexpectedResponse struct { 34 | Expected []string 35 | Received string 36 | } 37 | 38 | func (e *UnexpectedResponse) Error() string { 39 | return fmt.Sprintf("expected %s response(s), got %s", strings.Join(e.Expected, ", "), e.Received) 40 | } 41 | /* CqlError wraps a protocol-level message.Error so it can be returned as a Go error. */ 42 | type CqlError struct { 43 | Message message.Error 44 | } 45 | 46 | func (e CqlError) Error() string { 47 | return fmt.Sprintf("cql error: %v", e.Message) 48 | } 49 | /* isCriticalErr reports true for any error except io.EOF, ECONNREFUSED and ECONNRESET (each matched anywhere in err's wrap chain via errors.Is). */ 50 | func isCriticalErr(err error) bool { 51 | // Anything that's not a temporary unavailability 52 | return !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNREFUSED) && !errors.Is(err, syscall.ECONNRESET) 53 | } 54 | -------------------------------------------------------------------------------- /proxycore/host.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc.
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | /* Host is a node's Endpoint together with its data center (DC). */ 17 | type Host struct { 18 | Endpoint 19 | DC string 20 | } 21 | /* NewHostFromRow reads DC from the row's "data_center" column; it fails if that column is missing, null, or not a string. */ 22 | func NewHostFromRow(endpoint Endpoint, row Row) (*Host, error) { 23 | dc, err := row.StringByName("data_center") 24 | if err != nil { 25 | return nil, err 26 | } 27 | return &Host{endpoint, dc}, nil 28 | } 29 | /* Key and String delegate to the embedded Endpoint. */ 30 | func (h *Host) Key() string { 31 | return h.Endpoint.Key() 32 | } 33 | 34 | func (h *Host) String() string { 35 | return h.Endpoint.String() 36 | } 37 | -------------------------------------------------------------------------------- /proxycore/lb.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package proxycore 16 | 17 | import ( 18 | "sync" 19 | "sync/atomic" 20 | ) 21 | /* QueryPlan hands out candidate hosts one at a time; the round-robin plan below returns nil once every host has been returned. */ 22 | type QueryPlan interface { 23 | Next() *Host 24 | } 25 | /* LoadBalancer listens to cluster events to track hosts and creates query plans over them. */ 26 | type LoadBalancer interface { 27 | ClusterListener 28 | NewQueryPlan() QueryPlan 29 | } 30 | /* NewRoundRobinLoadBalancer returns a LoadBalancer with no hosts whose successive query plans each start one host further along. */ 31 | func NewRoundRobinLoadBalancer() LoadBalancer { 32 | lb := &roundRobinLoadBalancer{ 33 | mu: &sync.Mutex{}, 34 | } 35 | lb.hosts.Store(make([]*Host, 0)) 36 | return lb 37 | } 38 | /* hosts holds a []*Host snapshot that writers replace (never mutate) while holding mu, so readers can Load it without locking; index counts the query plans created so far. */ 39 | type roundRobinLoadBalancer struct { 40 | hosts atomic.Value 41 | index uint32 42 | mu *sync.Mutex 43 | } 44 | /* OnEvent swaps in a new host snapshot for bootstrap, add and remove events. */ 45 | func (l *roundRobinLoadBalancer) OnEvent(event Event) { 46 | l.mu.Lock() 47 | defer l.mu.Unlock() 48 | 49 | switch evt := event.(type) { 50 | case *BootstrapEvent: 51 | l.hosts.Store(evt.Hosts) /* NOTE(review): stored without copying, unlike Add/Remove below; assumes the caller does not mutate evt.Hosts afterwards */ 52 | case *AddEvent: 53 | l.hosts.Store(append(l.copy(), evt.Host)) 54 | case *RemoveEvent: 55 | cpy := l.copy() 56 | for i, h := range cpy { 57 | if h.Key() == evt.Host.Key() { 58 | l.hosts.Store(append(cpy[:i], cpy[i+1:]...)) 59 | break 60 | } 61 | } 62 | } 63 | } 64 | /* copy returns a private copy of the current hosts snapshot. */ 65 | func (l *roundRobinLoadBalancer) copy() []*Host { 66 | hosts := l.hosts.Load().([]*Host) 67 | cpy := make([]*Host, len(hosts)) 68 | copy(cpy, hosts) 69 | return cpy 70 | } 71 | /* NewQueryPlan takes the current snapshot and a starting offset from the shared, atomically incremented counter. */ 72 | func (l *roundRobinLoadBalancer) NewQueryPlan() QueryPlan { 73 | return &roundRobinQueryPlan{ 74 | hosts: l.hosts.Load().([]*Host), 75 | offset: atomic.AddUint32(&l.index, 1) - 1, 76 | index: 0, 77 | } 78 | } 79 | /* roundRobinQueryPlan returns each host of its snapshot exactly once, starting at offset mod len(hosts). */ 80 | type roundRobinQueryPlan struct { 81 | hosts []*Host 82 | offset uint32 83 | index uint32 84 | } 85 | /* Next returns nil once every host has been returned. */ 86 | func (p *roundRobinQueryPlan) Next() *Host { 87 | l := uint32(len(p.hosts)) 88 | if p.index >= l { 89 | return nil 90 | } 91 | host := p.hosts[(p.offset%l+p.index)%l] /* reduce offset first: near the top of the uint32 plan counter, offset+index wraps and would repeat one host while skipping another */ 92 | p.index++ 93 | return host 94 | } 95 | -------------------------------------------------------------------------------- /proxycore/lb_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc.
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "testing" 19 | 20 | "github.com/stretchr/testify/assert" 21 | ) 22 | /* TestRoundRobinLoadBalancer_NewQueryPlan: every NewQueryPlan call (including the first one on an empty balancer) advances the shared offset by one, so each plan below starts one position later than the previous plan, modulo the current host count. */ 23 | func TestRoundRobinLoadBalancer_NewQueryPlan(t *testing.T) { 24 | lb := NewRoundRobinLoadBalancer() 25 | 26 | qp := lb.NewQueryPlan() 27 | assert.Nil(t, qp.Next()) 28 | 29 | newHost := func(addr string) *Host { 30 | return &Host{Endpoint: &defaultEndpoint{addr: addr}} 31 | } 32 | 33 | lb.OnEvent(&BootstrapEvent{Hosts: []*Host{newHost("127.0.0.1"), newHost("127.0.0.2"), newHost("127.0.0.3")}}) 34 | qp = lb.NewQueryPlan() 35 | assert.Equal(t, newHost("127.0.0.2"), qp.Next()) 36 | assert.Equal(t, newHost("127.0.0.3"), qp.Next()) 37 | assert.Equal(t, newHost("127.0.0.1"), qp.Next()) 38 | assert.Nil(t, qp.Next()) 39 | 40 | lb.OnEvent(&AddEvent{Host: newHost("127.0.0.4")}) 41 | 42 | qp = lb.NewQueryPlan() 43 | assert.Equal(t, newHost("127.0.0.3"), qp.Next()) 44 | assert.Equal(t, newHost("127.0.0.4"), qp.Next()) 45 | assert.Equal(t, newHost("127.0.0.1"), qp.Next()) 46 | assert.Equal(t, newHost("127.0.0.2"), qp.Next()) 47 | assert.Nil(t, qp.Next()) 48 | 49 | lb.OnEvent(&RemoveEvent{Host: newHost("127.0.0.4")}) 50 | 51 | qp = lb.NewQueryPlan() 52 | assert.Equal(t, newHost("127.0.0.1"), qp.Next()) 53 | assert.Equal(t, newHost("127.0.0.2"), qp.Next()) 54 | assert.Equal(t, newHost("127.0.0.3"), qp.Next()) 55 | assert.Nil(t, qp.Next()) 56 | 57 | 
lb.OnEvent(&RemoveEvent{Host: newHost("127.0.0.3")}) 58 | 59 | qp = lb.NewQueryPlan() 60 | assert.Equal(t, newHost("127.0.0.1"), qp.Next()) 61 | assert.Equal(t, newHost("127.0.0.2"), qp.Next()) 62 | assert.Nil(t, qp.Next()) 63 | 64 | lb.OnEvent(&RemoveEvent{Host: newHost("127.0.0.2")}) 65 | 66 | qp = lb.NewQueryPlan() 67 | assert.Equal(t, newHost("127.0.0.1"), qp.Next()) 68 | assert.Nil(t, qp.Next()) 69 | 70 | lb.OnEvent(&RemoveEvent{Host: newHost("127.0.0.1")}) 71 | 72 | qp = lb.NewQueryPlan() 73 | assert.Nil(t, qp.Next()) 74 | } 75 | -------------------------------------------------------------------------------- /proxycore/log.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import "go.uber.org/zap" 18 | /* GetOrCreateNopLogger returns logger, or a new no-op zap logger when logger is nil. */ 19 | func GetOrCreateNopLogger(logger *zap.Logger) *zap.Logger { 20 | if logger == nil { 21 | return zap.NewNop() 22 | } 23 | return logger 24 | } 25 | -------------------------------------------------------------------------------- /proxycore/reconnpolicy.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "math/bits" 19 | "math/rand" 20 | "time" 21 | ) 22 | /* Delays used by NewReconnectPolicy. */ 23 | const ( 24 | defaultBaseDelay = 2 * time.Second 25 | defaultMaxDelay = 10 * time.Minute 26 | ) 27 | /* ReconnectPolicy produces successive reconnection delays; Reset restarts the sequence and Clone returns a fresh policy with the same delays. */ 28 | type ReconnectPolicy interface { 29 | NextDelay() time.Duration 30 | Reset() 31 | Clone() ReconnectPolicy 32 | } 33 | /* defaultReconnectPolicy adds an exponentially growing term (1ms << attempts) to baseDelay, capped at maxDelay. */ 34 | type defaultReconnectPolicy struct { 35 | attempts int 36 | maxAttempts int 37 | baseDelay time.Duration 38 | maxDelay time.Duration 39 | } 40 | 41 | func NewReconnectPolicy() ReconnectPolicy { 42 | return NewReconnectPolicyWithDelays(defaultBaseDelay, defaultMaxDelay) 43 | } 44 | 45 | func NewReconnectPolicyWithDelays(baseDelay, maxDelay time.Duration) ReconnectPolicy { 46 | return &defaultReconnectPolicy{ 47 | attempts: 0, 48 | maxAttempts: calcMaxAttempts(baseDelay), 49 | baseDelay: baseDelay, 50 | maxDelay: maxDelay, 51 | } 52 | } 53 | /* NextDelay returns baseDelay + (1ms << attempts) + 85-114ms of random jitter, capped at maxDelay; once maxAttempts delays have been handed out it returns maxDelay. */ 54 | func (d *defaultReconnectPolicy) NextDelay() time.Duration { 55 | if d.attempts >= d.maxAttempts { 56 | return d.maxDelay 57 | } 58 | jitter := time.Duration(rand.Intn(30)+85) * time.Millisecond 59 | exp := time.Millisecond << d.attempts 60 | d.attempts++ 61 | delay := d.baseDelay + exp + jitter 62 | if delay > d.maxDelay { 63 | delay = d.maxDelay 64 | } 65 | return delay 66 | } 67 | 68 | func (d *defaultReconnectPolicy) Reset() { 69 | d.attempts = 0 70 | } 71 | 72 | func (d defaultReconnectPolicy) Clone() ReconnectPolicy { 73 | return NewReconnectPolicyWithDelays(d.baseDelay, d.maxDelay) 74 | } 75 | /* calcMaxAttempts returns floor(log2(baseDelay in ns)), or -1 for a zero delay. NOTE(review): the cap is derived from baseDelay, not maxDelay, so a small base with a large max jumps straight to maxDelay once the cap is reached. */ 76 | func calcMaxAttempts(baseDelay 
time.Duration) int { 77 | return 63 - bits.LeadingZeros64(uint64(baseDelay)) 78 | } 79 | -------------------------------------------------------------------------------- /proxycore/reconnpolicy_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "github.com/stretchr/testify/assert" 19 | "math" 20 | "testing" 21 | "time" 22 | ) 23 | /* TestDefaultReconnectPolicy checks that the first delay is baseDelay (plus jitter), that delays reach and then stay at maxDelay, and that Clone and Reset restart the sequence. */ 24 | func TestDefaultReconnectPolicy(t *testing.T) { 25 | var tests = []struct { 26 | base time.Duration 27 | max time.Duration 28 | policy ReconnectPolicy 29 | }{ 30 | {defaultBaseDelay, defaultMaxDelay, NewReconnectPolicy()}, 31 | {time.Second, 2 * time.Minute, NewReconnectPolicyWithDelays(time.Second, 2*time.Minute)}, 32 | {200 * time.Millisecond, time.Hour, NewReconnectPolicyWithDelays(200*time.Millisecond, time.Hour)}, 33 | {time.Millisecond, 24 * time.Hour, NewReconnectPolicyWithDelays(time.Millisecond, 24*time.Hour)}, 34 | } 35 | for _, tt := range tests { 36 | verifyBaseWithJitter := func(policy ReconnectPolicy) { 37 | assert.InDelta(t, tt.base, policy.NextDelay(), float64((85+30)*time.Millisecond)) // include jitter 38 | } 39 | 40 | iterations := int(math.Ceil(math.Log2(float64((tt.max - tt.base) / time.Millisecond)))) /* number of calls after which the 1ms<<attempts term alone reaches max-base */ 41 | verifyBaseWithJitter(tt.policy) 42 | 43 | for i := 0; i < iterations-1; i++ { 44 | 
tt.policy.NextDelay() 45 | } 46 | assert.Equal(t, tt.max, tt.policy.NextDelay()) 47 | assert.Equal(t, tt.max, tt.policy.NextDelay()) // after max it should stay max 48 | 49 | verifyBaseWithJitter(tt.policy.Clone()) // cloned policy should be reset 50 | 51 | tt.policy.Reset() 52 | verifyBaseWithJitter(tt.policy) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /proxycore/requests.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "sync" 19 | 20 | "github.com/datastax/go-cassandra-native-protocol/frame" 21 | ) 22 | 23 | // Request represents the data frame and lifecycle of a CQL native protocol request. 24 | type Request interface { 25 | // Frame returns the frame to be executed as part of the request. 26 | // This must be idempotent. 27 | Frame() interface{} 28 | 29 | // IsPrepareRequest returns whether the request's frame is a `PREPARE` request. This is used to determine if the 30 | // prepared cache should be updated. 31 | IsPrepareRequest() bool 32 | 33 | // Execute is called when a request needs to be retried. 34 | // This is currently only called for executing prepared requests (i.e. `EXECUTE` request frames). 
If `EXECUTE` 35 | // request frames are not expected then the implementation should `panic()`. 36 | // 37 | // If `next` is false then the request must be retried on the current node; otherwise, it should be retried on 38 | // another node which is usually the next node in a query plan. 39 | Execute(next bool) 40 | 41 | // OnClose is called when the underlying connection is closed. 42 | // No assumptions should be made about whether the request has been successfully sent; it is possible that 43 | // the request has been fully sent and no response was received before the connection closed. 44 | OnClose(err error) 45 | 46 | // OnResult is called when a response frame has been sent back from the connection. 47 | OnResult(raw *frame.RawFrame) 48 | } 49 | /* pendingRequests maps in-flight stream IDs to their requests; unused stream IDs wait in the buffered streams channel. */ 50 | type pendingRequests struct { 51 | pending *sync.Map 52 | streams chan int16 53 | } 54 | /* newPendingRequests makes stream IDs 0..maxStreams-1 available. */ 55 | func newPendingRequests(maxStreams int16) *pendingRequests { 56 | streams := make(chan int16, maxStreams) 57 | for i := int16(0); i < maxStreams; i++ { 58 | streams <- i 59 | } 60 | return &pendingRequests{ 61 | pending: &sync.Map{}, 62 | streams: streams, 63 | } 64 | } 65 | /* store assigns a free stream ID to request, or returns -1 without blocking when none is free. */ 66 | func (p *pendingRequests) store(request Request) int16 { 67 | select { 68 | case stream := <-p.streams: 69 | p.pending.Store(stream, request) 70 | return stream 71 | default: 72 | return -1 73 | } 74 | } 75 | /* loadAndDelete removes and returns the request for stream (nil if there is none) and frees that stream ID. */ 76 | func (p *pendingRequests) loadAndDelete(stream int16) Request { 77 | request, ok := p.pending.LoadAndDelete(stream) 78 | if ok { 79 | p.streams <- stream 80 | return request.(Request) 81 | } 82 | return nil 83 | } 84 | /* closing calls OnClose(err) on every pending request; entries are left in the map. */ 85 | func (p *pendingRequests) closing(err error) { 86 | p.pending.Range(func(key, value interface{}) bool { 87 | request := value.(Request) 88 | request.OnClose(err) 89 | return true 90 | }) 91 | } 92 | -------------------------------------------------------------------------------- /proxycore/requests_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc.
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "io" 19 | "testing" 20 | 21 | "github.com/datastax/go-cassandra-native-protocol/frame" 22 | "github.com/stretchr/testify/assert" 23 | ) 24 | /* TestPendingRequests fills every stream ID, frees and reuses two of them, then checks that closing notifies all pending requests with the given error. */ 25 | func TestPendingRequests(t *testing.T) { 26 | const max = 10 27 | 28 | p := newPendingRequests(max) 29 | 30 | errs := make([]error, 0) 31 | 32 | for i := int16(0); i < max; i++ { 33 | assert.Equal(t, i, p.store(&testPendingRequest{stream: i, errs: &errs})) 34 | } 35 | assert.Equal(t, int16(-1), p.store(&testPendingRequest{})) 36 | 37 | r := p.loadAndDelete(0).(*testPendingRequest) 38 | assert.Equal(t, int16(0), r.stream) 39 | 40 | r = p.loadAndDelete(9).(*testPendingRequest) 41 | assert.Equal(t, int16(9), r.stream) 42 | 43 | assert.Equal(t, int16(0), p.store(&testPendingRequest{stream: 0, errs: &errs})) 44 | assert.Equal(t, int16(9), p.store(&testPendingRequest{stream: 9, errs: &errs})) 45 | assert.Equal(t, int16(-1), p.store(&testPendingRequest{})) 46 | 47 | p.closing(io.EOF) 48 | 49 | assert.Equal(t, 10, len(errs)) 50 | 51 | for _, err := range errs { 52 | assert.ErrorIs(t, err, io.EOF) /* ErrorAs(err, &io.EOF) matched any error and would overwrite the io.EOF global */ 53 | } 54 | } 55 | /* testPendingRequest records the errors passed to OnClose; its other methods must not be called. */ 56 | type testPendingRequest struct { 57 | stream int16 58 | errs *[]error 59 | } 60 | 61 | func (t *testPendingRequest) Execute(_ bool) { 62 | panic("not implemented") 63 | } 64 | 65 | func (t *testPendingRequest) IsPrepareRequest() bool { 66 | panic("not implemented") 67 
| } 68 | 69 | func (t *testPendingRequest) Frame() interface{} { 70 | panic("not implemented") 71 | } 72 | 73 | func (t *testPendingRequest) OnClose(err error) { 74 | *t.errs = append(*t.errs, err) 75 | } 76 | 77 | func (t *testPendingRequest) OnResult(_ *frame.RawFrame) { 78 | panic("not implemented") 79 | } 80 | -------------------------------------------------------------------------------- /proxycore/resultset.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package proxycore 16 | 17 | import ( 18 | "errors" 19 | "fmt" 20 | "net" 21 | 22 | "github.com/datastax/cql-proxy/codecs" 23 | "github.com/datastax/go-cassandra-native-protocol/message" 24 | "github.com/datastax/go-cassandra-native-protocol/primitive" 25 | ) 26 | /* Errors returned by the Row accessors. */ 27 | var ( 28 | ColumnNameNotFound = errors.New("column name not found") 29 | ColumnIsNull = errors.New("column is null") 30 | ) 31 | /* ResultSet wraps a message.RowsResult with a column-name index for by-name lookups. */ 32 | type ResultSet struct { 33 | columnIndexes map[string]int 34 | result *message.RowsResult 35 | version primitive.ProtocolVersion 36 | } 37 | /* Row is a single row of a ResultSet; values are decoded on each access, not up front. */ 38 | type Row struct { 39 | resultSet *ResultSet 40 | row message.Row 41 | } 42 | /* NewResultSet indexes the columns of rows by name (a later duplicate name wins); version is used when decoding values. */ 43 | func NewResultSet(rows *message.RowsResult, version primitive.ProtocolVersion) *ResultSet { 44 | columnIndexes := make(map[string]int) 45 | for i, column := range rows.Metadata.Columns { 46 | columnIndexes[column.Name] = i 47 | } 48 | return &ResultSet{ 49 | columnIndexes: columnIndexes, 50 | result: rows, 51 | version: version, 52 | } 53 | } 54 | /* Row returns row i; it panics if i is out of range. */ 55 | func (rs *ResultSet) Row(i int) Row { 56 | return Row{ 57 | rs, 58 | rs.result.Data[i]} 59 | } 60 | /* RowCount returns the number of rows. */ 61 | func (rs ResultSet) RowCount() int { 62 | return len(rs.result.Data) 63 | } 64 | /* ByPos decodes column i of the row using that column's metadata type. */ 65 | func (r Row) ByPos(i int) (interface{}, error) { 66 | val, err := codecs.DecodeType(r.resultSet.result.Metadata.Columns[i].Type, r.resultSet.version, r.row[i]) 67 | if err != nil { 68 | return nil, err 69 | } 70 | return val, nil 71 | } 72 | /* ByName decodes the named column, or returns ColumnNameNotFound. */ 73 | func (r Row) ByName(n string) (interface{}, error) { 74 | if i, ok := r.resultSet.columnIndexes[n]; !ok { 75 | return nil, ColumnNameNotFound 76 | } else { 77 | return r.ByPos(i) 78 | } 79 | } 80 | /* StringByName returns the named column as a string, or ColumnIsNull if it is null. */ 81 | func (r Row) StringByName(n string) (string, error) { 82 | val, err := r.ByName(n) 83 | if err != nil { 84 | return "", err 85 | } 86 | if val == nil { 87 | return "", ColumnIsNull 88 | } else if s, ok := val.(string); !ok { 89 | return "", fmt.Errorf("'%s' is not a string", n) 90 | } else { 91 | return s, nil 92 | } 93 | } 94 | /* InetByName returns the named column as a net.IP, or ColumnIsNull if it is null. */ 95 | func (r Row) InetByName(n 
string) (net.IP, error) { 96 | val, err := r.ByName(n) 97 | if err != nil { 98 | return nil, err 99 | } 100 | if val == nil { 101 | return nil, ColumnIsNull 102 | } else if ip, ok := val.(net.IP); !ok { 103 | return nil, fmt.Errorf("'%s' is not an inet (or is null)", n) /* NOTE(review): null is already handled above, so the "(or is null)" wording is stale */ 104 | } else { 105 | return ip, nil 106 | } 107 | } 108 | /* UUIDByName returns the named column as a primitive.UUID, or ColumnIsNull if it is null. */ 109 | func (r Row) UUIDByName(n string) (primitive.UUID, error) { 110 | val, err := r.ByName(n) 111 | if err != nil { 112 | return [16]byte{}, err 113 | } 114 | if val == nil { 115 | return [16]byte{}, ColumnIsNull 116 | } else if u, ok := val.(primitive.UUID); !ok { 117 | return [16]byte{}, fmt.Errorf("'%s' is not a uuid (or is null)", n) 118 | } else { 119 | return u, nil 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /proxycore/session.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "context" 19 | "errors" 20 | "sync" 21 | "time" 22 | 23 | "github.com/datastax/go-cassandra-native-protocol/frame" 24 | "github.com/datastax/go-cassandra-native-protocol/primitive" 25 | "go.uber.org/zap" 26 | ) 27 | /* NoConnForHost is returned by Session.Send when the session has no usable connection for the host. */ 28 | var ( 29 | NoConnForHost = errors.New("no connection available for host") 30 | ) 31 | 32 | // PreparedEntry is an entry in the prepared cache.
33 | type PreparedEntry struct { 34 | PreparedFrame *frame.RawFrame 35 | } 36 | 37 | // PreparedCache is a thread-safe cache for storing prepared queries. 38 | type PreparedCache interface { 39 | // Store adds an entry to the cache. 40 | Store(id string, entry *PreparedEntry) 41 | // Load retrieves an entry from the cache. `ok` is true if the entry is present; otherwise it's false. 42 | Load(id string) (entry *PreparedEntry, ok bool) 43 | } 44 | /* SessionConfig configures the connection pools a Session opens to each host. */ 45 | type SessionConfig struct { 46 | ReconnectPolicy ReconnectPolicy 47 | NumConns int 48 | Keyspace string 49 | Version primitive.ProtocolVersion 50 | Auth Authenticator 51 | // PreparedCache is a global cache, shared across sessions, for storing previously prepared queries 52 | PreparedCache PreparedCache 53 | ConnectTimeout time.Duration 54 | HeartBeatInterval time.Duration 55 | IdleTimeout time.Duration 56 | Logger *zap.Logger 57 | Compression string 58 | } 59 | /* Session keeps one connection pool per host, keyed by Host.Key(). */ 60 | type Session struct { 61 | ctx context.Context 62 | config SessionConfig 63 | logger *zap.Logger 64 | pools sync.Map 65 | connected chan struct{} 66 | failed chan error 67 | } 68 | /* ConnectSession registers the session as a cluster listener, then waits for the bootstrap pools to connect, for a pool connection failure, or for ctx to be done. */ 69 | func ConnectSession(ctx context.Context, cluster *Cluster, config SessionConfig) (*Session, error) { 70 | session := &Session{ 71 | ctx: ctx, 72 | config: config, 73 | logger: GetOrCreateNopLogger(config.Logger), 74 | pools: sync.Map{}, 75 | connected: make(chan struct{}), 76 | failed: make(chan error, 1), 77 | } 78 | 79 | err := cluster.Listen(session) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | select { 85 | case <-ctx.Done(): 86 | return nil, ctx.Err() 87 | case <-session.connected: 88 | return session, nil 89 | case err = <-session.failed: 90 | return nil, err 91 | } 92 | } 93 | /* Send dispatches request on the least busy connection to host, or returns NoConnForHost if none is available. */ 94 | func (s *Session) Send(host *Host, request Request) error { 95 | conn := s.leastBusyConn(host) 96 | if conn == nil { 97 | return NoConnForHost 98 | } 99 | return conn.Send(request) 100 | } 101 | /* leastBusyConn returns the least busy connection in host's pool, or nil if the session has no pool for host. */ 102 | func (s *Session) leastBusyConn(host *Host) *ClientConn { 103 | if p, ok := 
s.pools.Load(host.Key()); ok { 104 | pool := p.(*connPool) 105 | return pool.leastBusyConn() 106 | } 107 | return nil 108 | } 109 | /* OnEvent keeps the per-host pools in step with cluster bootstrap, add and remove events. */ 110 | func (s *Session) OnEvent(event Event) { 111 | switch evt := event.(type) { 112 | case *BootstrapEvent: 113 | go func() { 114 | var wg sync.WaitGroup 115 | 116 | count := len(evt.Hosts) 117 | wg.Add(count) 118 | 119 | for _, host := range evt.Hosts { 120 | go func(host *Host) { 121 | pool, err := connectPool(s.ctx, connPoolConfig{ 122 | Endpoint: host.Endpoint, 123 | SessionConfig: s.config, 124 | }) 125 | if err != nil { 126 | select { 127 | case s.failed <- err: 128 | default: 129 | } 130 | } 131 | s.pools.Store(host.Key(), pool) /* NOTE(review): stored even when connectPool failed; if pool is nil in that case, a later leastBusyConn would call a method on a nil *connPool -- confirm connectPool's error contract */ 132 | wg.Done() 133 | }(host) 134 | } 135 | 136 | wg.Wait() 137 | 138 | close(s.connected) 139 | }() 140 | case *AddEvent: 141 | // sync.Map has no compute-if-absent, so a pool is connected eagerly; if this host already has one, keep the existing pool and cancel the redundant new one (cancelling the existing one would leave a dead pool in the map and leak the new one). 142 | newPool := connectPoolNoFail(s.ctx, connPoolConfig{ 143 | Endpoint: evt.Host.Endpoint, 144 | SessionConfig: s.config, 145 | }) 146 | if _, loaded := s.pools.LoadOrStore(evt.Host.Key(), newPool); loaded { 147 | newPool.cancel() 148 | } 149 | case *RemoveEvent: 150 | if pool, ok := s.pools.LoadAndDelete(evt.Host.Key()); ok { 151 | p := pool.(*connPool) 152 | p.cancel() 153 | } 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /proxycore/session_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) DataStax, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package proxycore 16 | 17 | import ( 18 | "context" 19 | "net" 20 | "sync" 21 | "testing" 22 | "time" 23 | 24 | "github.com/datastax/cql-proxy/codecs" 25 | "github.com/datastax/go-cassandra-native-protocol/frame" 26 | "github.com/datastax/go-cassandra-native-protocol/message" 27 | "github.com/datastax/go-cassandra-native-protocol/primitive" 28 | "github.com/stretchr/testify/assert" 29 | "github.com/stretchr/testify/require" 30 | "go.uber.org/zap" 31 | ) 32 | /* TestConnectSession connects a session to a 3-node mock cluster and queries each node, then checks that a node added later gets a pool and that the pool disappears once the node is removed. */ 33 | func TestConnectSession(t *testing.T) { 34 | logger, _ := zap.NewDevelopment() 35 | 36 | ctx, cancel := context.WithCancel(context.Background()) 37 | defer cancel() 38 | 39 | const supported = primitive.ProtocolVersion4 40 | 41 | c := NewMockCluster(net.ParseIP("127.0.0.0"), 9042) 42 | 43 | err := c.Add(ctx, 1) 44 | require.NoError(t, err) 45 | 46 | err = c.Add(ctx, 2) 47 | require.NoError(t, err) 48 | 49 | err = c.Add(ctx, 3) 50 | require.NoError(t, err) 51 | 52 | cluster, err := ConnectCluster(ctx, ClusterConfig{ 53 | Version: supported, 54 | Resolver: NewResolver("127.0.0.1:9042"), 55 | ReconnectPolicy: NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second), 56 | RefreshWindow: 100 * time.Millisecond, 57 | ConnectTimeout: 10 * time.Second, 58 | Logger: logger, 59 | HeartBeatInterval: 30 * time.Second, 60 | IdleTimeout: 60 * time.Second, 61 | }) 62 | require.NoError(t, err) 63 | 64 | session, err := ConnectSession(ctx, cluster, SessionConfig{ 65 | ReconnectPolicy: NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second), 66 | 
NumConns: 2, 67 | Version: supported, 68 | ConnectTimeout: 10 * time.Second, 69 | HeartBeatInterval: 30 * time.Second, 70 | IdleTimeout: 60 * time.Second, 71 | }) 72 | require.NoError(t, err) 73 | 74 | newHost := func(addr string) *Host { 75 | return &Host{Endpoint: &defaultEndpoint{addr: addr}} 76 | } 77 | 78 | var wg sync.WaitGroup 79 | 80 | wg.Add(3) 81 | 82 | err = session.Send(newHost("127.0.0.1:9042"), &testSessionRequest{t: t, version: supported, rpcAddr: "127.0.0.1", wg: &wg}) 83 | require.NoError(t, err) 84 | 85 | err = session.Send(newHost("127.0.0.2:9042"), &testSessionRequest{t: t, version: supported, rpcAddr: "127.0.0.2", wg: &wg}) 86 | require.NoError(t, err) 87 | 88 | err = session.Send(newHost("127.0.0.3:9042"), &testSessionRequest{t: t, version: supported, rpcAddr: "127.0.0.3", wg: &wg}) 89 | require.NoError(t, err) 90 | 91 | wg.Wait() 92 | 93 | err = c.Add(ctx, 4) 94 | require.NoError(t, err) 95 | 96 | available := waitUntil(10*time.Second, func() bool { 97 | return session.leastBusyConn(newHost("127.0.0.4:9042")) != nil 98 | }) 99 | require.True(t, available) 100 | 101 | wg.Add(1) 102 | 103 | err = session.Send(newHost("127.0.0.4:9042"), &testSessionRequest{t: t, version: supported, rpcAddr: "127.0.0.4", wg: &wg}) 104 | require.NoError(t, err) 105 | 106 | wg.Wait() 107 | 108 | c.Remove(4) 109 | 110 | removed := waitUntil(10*time.Second, func() bool { 111 | return session.leastBusyConn(newHost("127.0.0.4:9042")) == nil 112 | }) 113 | require.True(t, removed) 114 | } 115 | /* testSessionRequest sends "SELECT * FROM system.local" and asserts that the returned rpc_address equals rpcAddr. */ 116 | type testSessionRequest struct { 117 | t *testing.T 118 | version primitive.ProtocolVersion /* used to decode the response; must match the version the frame is sent with */ 119 | rpcAddr string 120 | wg *sync.WaitGroup 121 | } 122 | 123 | func (r testSessionRequest) Execute(next bool) { 124 | panic("not implemented") 125 | } 126 | 127 | func (r testSessionRequest) Frame() interface{} { 128 | return frame.NewFrame(primitive.ProtocolVersion4, -1, &message.Query{ 129 | Query: "SELECT * FROM system.local", 130 | }) 131 | } 132 | 133 | func (r testSessionRequest) IsPrepareRequest() bool { 134 | return false 135 | } 136 | /* OnClose fails the test: the connection is not expected to close. */ 137 | func (r 
testSessionRequest) OnClose(_ error) { 138 | require.Fail(r.t, "connection unexpectedly closed") 139 | } 140 | /* OnResult decodes the ROWS response and checks rpc_address. NOTE(review): require.* calls t.FailNow, which must run on the test goroutine; these callbacks presumably run on a connection goroutine, so assert.* may be safer here. */ 141 | func (r testSessionRequest) OnResult(raw *frame.RawFrame) { 142 | frm, err := codecs.DefaultRawCodec.ConvertFromRawFrame(raw) 143 | require.NoError(r.t, err) 144 | 145 | switch msg := frm.Body.Message.(type) { 146 | case *message.RowsResult: 147 | rs := NewResultSet(msg, r.version) 148 | rpcAddr, err := rs.Row(0).ByName("rpc_address") 149 | require.NoError(r.t, err) 150 | assert.Equal(r.t, r.rpcAddr, rpcAddr.(net.IP).String()) 151 | default: 152 | require.Fail(r.t, "invalid message body") 153 | } 154 | 155 | r.wg.Done() 156 | } 157 | --------------------------------------------------------------------------------