├── .dockerignore ├── .editorconfig ├── .github ├── dependabot.yml └── workflows │ ├── go.yml │ └── image.yml ├── .gitignore ├── Dockerfile.server ├── LICENSE ├── Makefile ├── README.md ├── charts └── sqlite-rest │ ├── .helmignore │ ├── Chart.yaml │ ├── templates │ ├── NOTES.txt │ ├── _helpers.tpl │ ├── service.yaml │ ├── serviceaccount.yaml │ └── statefulset.yaml │ └── values.yaml ├── db.go ├── docs └── assets │ └── logo.svg ├── examples ├── bookstore │ └── data.sql └── migrations │ ├── 1_create_books_table.drop.sql │ └── 1_create_books_table.up.sql ├── fixture_test.go ├── fs.go ├── go.mod ├── go.sum ├── integration_delete_test.go ├── integration_insert_test.go ├── integration_migrate_test.go ├── integration_security_test.go ├── integration_select_test.go ├── integration_update_test.go ├── logger.go ├── main.go ├── metrics.go ├── metrics_test.go ├── migrate.go ├── query.go ├── server.go ├── server_auth.go ├── server_errors.go ├── server_security.go ├── server_utils.go └── version.go /.dockerignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | docs/ 3 | charts/ 4 | 5 | data/ 6 | *.sqlite3 7 | *-shm 8 | *-wal 9 | *.token 10 | 11 | sqlite-rest 12 | test.token -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*.{yaml,yml}] 4 | indent_size = 2 -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "gomod" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: "Unit Test" 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | test: 11 | strategy: 12 | matrix: 13 | go-version: [">=1.21"] 14 | os: [ubuntu-latest] 15 | runs-on: ${{ matrix.os }} 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v3 19 | - name: Install Go 20 | uses: actions/setup-go@v3 21 | with: 22 | go-version: ${{ matrix.go-version }} 23 | - name: Test 24 | run: | 25 | go test -v "./..." 
26 | -------------------------------------------------------------------------------- /.github/workflows/image.yml: -------------------------------------------------------------------------------- 1 | name: Publish Image 2 | 3 | on: 4 | release: 5 | action: released 6 | # NOTE: this is for testing the workflow 7 | push: 8 | branches: ["release-test", "main"] 9 | workflow_dispatch: {} 10 | 11 | env: 12 | REGISTRY: ghcr.io 13 | IMAGE_NAME: ${{ github.repository }}/server 14 | 15 | jobs: 16 | build-and-push-image: 17 | runs-on: ubuntu-latest 18 | permissions: 19 | contents: read 20 | packages: write 21 | 22 | steps: 23 | - name: Log in to the Container registry 24 | uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 25 | with: 26 | registry: ${{ env.REGISTRY }} 27 | username: ${{ github.actor }} 28 | password: ${{ secrets.GITHUB_TOKEN }} 29 | 30 | - name: Extract metadata (tags, labels) for Docker 31 | id: meta 32 | uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38 33 | with: 34 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 35 | 36 | - name: Build and push Docker image 37 | uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc 38 | with: 39 | file: ./Dockerfile.server 40 | push: true 41 | tags: ${{ steps.meta.outputs.tags }} 42 | labels: ${{ steps.meta.outputs.labels }} 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | 3 | data/ 4 | *.sqlite3 5 | *-shm 6 | *-wal 7 | *.token 8 | 9 | sqlite-rest 10 | test.token -------------------------------------------------------------------------------- /Dockerfile.server: -------------------------------------------------------------------------------- 1 | FROM golang:1.21 as builder 2 | 3 | WORKDIR /workspace 4 | # Copy the Go Modules manifests 5 | COPY go.mod go.mod 6 | COPY go.sum go.sum 7 | # cache deps before building and copying source so that we don't need to re-download as much 8 | # and so that source changes don't invalidate our downloaded layer 9 | RUN go mod download 10 | 11 | COPY . . 12 | 13 | # Build 14 | RUN GOOS=linux CGO_ENABLED=1 GOARCH=amd64 \ 15 | go build -trimpath -v -x -o bin/sqlite-rest ./ 16 | 17 | FROM docker.io/library/debian:stable-slim 18 | 19 | RUN mkdir -p /workspace 20 | 21 | WORKDIR /workspace 22 | 23 | COPY --from=builder /workspace/bin/sqlite-rest /bin/sqlite-rest 24 | 25 | ENTRYPOINT [ "/bin/sqlite-rest" ] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2022 b4fun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL = bash 2 | 3 | all: help 4 | 5 | ##@ General 6 | 7 | # The help target prints out all targets with their descriptions organized 8 | # beneath their categories. The categories are represented by '##@' and the 9 | # target descriptions by '##'. The awk command is responsible for reading the 10 | # entire set of makefiles included in this invocation, looking for lines of the 11 | # file as xyz: ## something, and then pretty-format the target and help. Then, 12 | # if there's a line with ##@ something, that gets pretty-printed as a category. 13 | # More info on the usage of ANSI control characters for terminal formatting: 14 | # https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters 15 | # More info on the awk command: 16 | # http://linuxcommand.org/lc3_adv_awk.php 17 | 18 | help: ## Display this help. 19 | @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m<target>\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) 20 | 21 | ##@ Development 22 | 23 | fmt: ## Run go fmt against code. 24 | go fmt ./... 25 | 26 | vet: ## Run go vet against code. 27 | go vet ./... 28 | 29 | build-server: fmt vet ## Build server. 30 | go build . 31 | 32 | run-server: build-server ## Run server. 33 | echo -n "test" > local_dev.token 34 | ./sqlite-rest serve --db-dsn ./test.sqlite3?_journal_mode=WAL --http-addr 127.0.0.1:8080 --metrics-addr 127.0.0.1:8081 --pprof-addr 127.0.0.1:8082 --log-devel --log-level 12 --auth-token-file local_dev.token --security-allow-table fruit 35 | 36 | run-migrate: build-server ## Run migration. 37 | ./sqlite-rest migrate --db-dsn ./test.sqlite3?_journal_mode=WAL --log-devel --log-level 12 ./data 38 | 39 | ##@ Build 40 | 41 | DOCKER_CMD ?= docker 42 | IMG_TAG ?= latest 43 | IMG_PREFIX ?= ghcr.io/b4fun/sqlite-rest 44 | IMG_BUILD_OPTS ?= --platform=linux/amd64 45 | 46 | build: fmt vet ## Build binary. 47 | go build -o ./sqlite-rest . 48 | 49 | build-image: build-image-server ## Build docker images. 50 | 51 | build-image-server: ## Build server docker image. 52 | ${DOCKER_CMD} build ${IMG_BUILD_OPTS} \ 53 | -f Dockerfile.server \ 54 | -t ${IMG_PREFIX}/server:${IMG_TAG} . 55 | 56 | push-image: ## Push docker images. 57 | ${DOCKER_CMD} push ${IMG_PREFIX}/server:${IMG_TAG} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | [project logo: docs/assets/logo.svg] 3 | 4 | 5 | 6 | 7 | 8 | Serve a RESTful API from any SQLite database 9 |

10 | 11 | **sqlite-rest** is similar to [PostgREST][postgrest], but for [SQLite][sqlite]. It's a standalone web server that adds a RESTful API to any SQLite database. 12 | 13 | [PostgREST]: https://postgrest.org/en/stable/ 14 | [SQLite]: https://www.sqlite.org/ 15 | 16 | ## Installation 17 | 18 | ### Build From Source 19 | 20 | ``` 21 | $ go install github.com/b4fun/sqlite-rest@latest 22 | $ sqlite-rest 23 | 24 | ``` 25 | 26 | ### Using Docker image 27 | 28 | ``` 29 | $ docker run -it --rm ghcr.io/b4fun/sqlite-rest/server:main 30 | 31 | ``` 32 | 33 | ## Quick Start 34 | 35 | Suppose we are serving a bookstore database with the following schema: 36 | 37 | ```sql 38 | CREATE TABLE books ( 39 | id INTEGER PRIMARY KEY, 40 | title TEXT NOT NULL, 41 | author TEXT NOT NULL, 42 | price REAL NOT NULL 43 | ); 44 | ``` 45 | 46 | ### Create a database 47 | 48 | ``` 49 | $ sqlite3 bookstore.sqlite3 < examples/bookstore/data.sql 50 | ``` 51 | 52 | ### Start server 53 | 54 | ``` 55 | $ echo -n "topsecret" > test.token 56 | $ sqlite-rest serve --auth-token-file test.token --security-allow-table books --db-dsn ./bookstore.sqlite3 57 | {"level":"info","ts":1672528510.825417,"logger":"db-server","caller":"sqlite-rest/server.go:121","msg":"server started","addr":":8080"} 58 | ... 59 | ``` 60 | 61 | ### Generate authentication token 62 | 63 | **NOTE: the following steps create a sample token for testing only; use a strong secret in production.** 64 | 65 | - Visit https://jwt.io/ 66 | - Choose `HS256` as the algorithm 67 | - Enter `topsecret` as the secret 68 | - Copy the encoded JWT from the output 69 | - Export the token as an environment variable 70 | 71 | ``` 72 | $ export AUTH_TOKEN=<encoded JWT> 73 | ``` 74 | 75 | 76 | ### Querying 77 | 78 | **Querying by book id** 79 | 80 | ``` 81 | $ curl -H "Authorization: Bearer $AUTH_TOKEN" http://127.0.0.1:8080/books?id=eq.1 82 | [ 83 | { 84 | "author": "Stephen King", 85 | "id": 1, 86 | "price": 23.54, 87 | "title": "Fairy Tale" 88 | } 89 | ] 90 | ``` 91 | 92 | **Querying by book price** 93 | 94 | ``` 95 | $ curl -H "Authorization: Bearer $AUTH_TOKEN" http://127.0.0.1:8080/books?price=lt.10 96 | [ 97 | { 98 | "author": "Alice Hoffman", 99 | "id": 2, 100 | "price": 1.99, 101 | "title": "The Bookstore Sisters: A Short Story" 102 | }, 103 | { 104 | "author": "Caroline Peckham", 105 | "id": 4, 106 | "price": 8.99, 107 | "title": "Zodiac Academy 8: Sorrow and Starlight" 108 | } 109 | ] 110 | ``` 111 | 112 | ## Features 113 | 114 | ### Parity with PostgREST 115 | 116 | sqlite-rest aims to implement the same API as PostgREST, but not every feature is implemented yet. Below is a list of the features currently supported in sqlite-rest. If you need a feature that is missing from this list, feel free to create an issue :smile: 117 | 118 | - Tables and Views 119 | - [x] Horizontal Filtering (Rows) 120 | - [x] Vertical Filtering (Columns) 121 | - [x] Unicode support 122 | - [x] Ordering 123 | - [x] Limit and Pagination 124 | - [x] Exact Count 125 | - Insertions 126 | - [x] Specifying Columns 127 | - [x] Updates 128 | - [x] Upsert 129 | - [x] Deletions 130 | 131 | ### Authentication 132 | 133 | sqlite-rest provides built-in JWT-based authentication. To use the `HS256` / `HS384` / `HS512` algorithms, specify the token file to read via the `--auth-token-file` flag. To use the `RS256` / `RS384` / `RS512` algorithms, specify the public key file via the `--auth-rsa-public-key` flag.
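
Instead of using jwt.io, a client can also mint the token itself. The snippet below is a minimal, illustrative sketch (not part of this repository): it signs an empty-claims `HS256` JWT with the Quick Start secret `topsecret`, using the same `github.com/golang-jwt/jwt` package this project already depends on, and then queries the `books` table. It assumes the server from the Quick Start is listening on `127.0.0.1:8080`.

```go
package main

import (
	"fmt"
	"io"
	"net/http"

	"github.com/golang-jwt/jwt"
)

func main() {
	// Sign a JWT with the same secret the server reads from --auth-token-file
	// ("topsecret" in the Quick Start). An empty claim set is enough here.
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwt.StandardClaims{})
	signed, err := token.SignedString([]byte("topsecret"))
	if err != nil {
		panic(err)
	}

	// Query the books table using the PostgREST-style filter syntax.
	req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8080/books?price=lt.10", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer "+signed)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Status)
	fmt.Println(string(body))
}
```

The same filter syntax shown in the curl examples (`eq.`, `lt.`, ...) applies to any HTTP client; only the bearer token changes with the signing algorithm you configure.
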
134 | 135 | ### Tables/Views Access 136 | 137 | By default, sqlite-rest exposes **no** tables/views. To allow access to specific tables/views, use the `--security-allow-table` flag: 138 | 139 | **one table** 140 | 141 | ``` 142 | --security-allow-table books 143 | ``` 144 | 145 | **multiple tables** 146 | 147 | ``` 148 | --security-allow-table books,authors 149 | ``` 150 | 151 | ### Metrics 152 | 153 | sqlite-rest exposes metrics in [Prometheus][prometheus] format. By default, metrics are served at the `:8081/metrics` endpoint. To change the address, use the `--metrics-addr` flag. To disable metrics, set `--metrics-addr` to `""`. 154 | 155 | Recorded metrics can be found in [metrics.go](metrics.go). 156 | 157 | [prometheus]: https://prometheus.io/ 158 | 159 | ### Database Migrations 160 | 161 | sqlite-rest supports database migrations via [golang-migrate][golang-migrate]. 162 | 163 | **Apply migrations** 164 | 165 | ``` 166 | $ sqlite-rest migrate --db-dsn ./bookstore.sqlite3 ./examples/migrations 167 | {"level":"info","ts":1672614524.2731035,"logger":"db-migrator.up","caller":"sqlite-rest/migrate.go:136","msg":"applying operation"} 168 | {"level":"info","ts":1672614524.3081956,"logger":"db-migrator.up","caller":"sqlite-rest/migrate.go:140","msg":"applied operation"} 169 | ``` 170 | 171 | **Rollback migrations** 172 | 173 | ``` 174 | $ sqlite-rest migrate --db-dsn ./bookstore.sqlite3 --direction down --step 1 ./examples/migrations 175 | ``` 176 | 177 | [golang-migrate]: https://github.com/golang-migrate/migrate 178 | 179 | ## License 180 | 181 | MIT -------------------------------------------------------------------------------- /charts/sqlite-rest/.helmignore: -------------------------------------------------------------------------------- 1 | # Patterns to ignore when building packages. 2 | # This supports shell glob matching, relative path matching, and 3 | # negation (prefixed with !). Only one pattern per line. 4 | .DS_Store 5 | # Common VCS dirs 6 | .git/ 7 | .gitignore 8 | .bzr/ 9 | .bzrignore 10 | .hg/ 11 | .hgignore 12 | .svn/ 13 | # Common backup files 14 | *.swp 15 | *.bak 16 | *.tmp 17 | *.orig 18 | *~ 19 | # Various IDEs 20 | .project 21 | .idea/ 22 | *.tmproj 23 | .vscode/ 24 | -------------------------------------------------------------------------------- /charts/sqlite-rest/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: sqlite-rest 3 | description: A Helm chart for sqlite-rest 4 | 5 | # A chart can be either an 'application' or a 'library' chart. 6 | # 7 | # Application charts are a collection of templates that can be packaged into versioned archives 8 | # to be deployed. 9 | # 10 | # Library charts provide useful utilities or functions for the chart developer. They're included as 11 | # a dependency of application charts to inject those utilities and functions into the rendering 12 | # pipeline. Library charts do not define any templates and therefore cannot be deployed. 13 | type: application 14 | 15 | # This is the chart version. This version number should be incremented each time you make changes 16 | # to the chart and its templates, including the app version. 17 | # Versions are expected to follow Semantic Versioning (https://semver.org/) 18 | version: 0.1.1 19 | 20 | # This is the version number of the application being deployed. This version number should be 21 | # incremented each time you make changes to the application.
Versions are not expected to 22 | # follow Semantic Versioning. They should reflect the version the application is using. 23 | # It is recommended to use it with quotes. 24 | appVersion: "1.16.0" 25 | 26 | sources: 27 | - "https://github.com/b4fun/sqlite-rest" 28 | -------------------------------------------------------------------------------- /charts/sqlite-rest/templates/NOTES.txt: -------------------------------------------------------------------------------- 1 | 1. Get the application URL by running these commands: 2 | {{- if contains "NodePort" .Values.service.type }} 3 | export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "sqlite-rest.fullname" . }}) 4 | export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}") 5 | echo http://$NODE_IP:$NODE_PORT 6 | {{- else if contains "LoadBalancer" .Values.service.type }} 7 | NOTE: It may take a few minutes for the LoadBalancer IP to be available. 8 | You can watch the status of by running 'kubectl get --namespace {{ .Release.Namespace }} svc -w {{ include "sqlite-rest.fullname" . }}' 9 | export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "sqlite-rest.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}") 10 | echo http://$SERVICE_IP:{{ .Values.service.port }} 11 | {{- else if contains "ClusterIP" .Values.service.type }} 12 | export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "sqlite-rest.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" -o jsonpath="{.items[0].metadata.name}") 13 | export CONTAINER_PORT=$(kubectl get pod --namespace {{ .Release.Namespace }} $POD_NAME -o jsonpath="{.spec.containers[0].ports[0].containerPort}") 14 | echo "Visit http://127.0.0.1:8080 to use your application" 15 | kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 8080:$CONTAINER_PORT 16 | {{- end }} 17 | -------------------------------------------------------------------------------- /charts/sqlite-rest/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "sqlite-rest.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | Create a default fully qualified app name. 10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). 11 | If release name contains chart name it will be used as a full name. 12 | */}} 13 | {{- define "sqlite-rest.fullname" -}} 14 | {{- if .Values.fullnameOverride }} 15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 16 | {{- else }} 17 | {{- $name := default .Chart.Name .Values.nameOverride }} 18 | {{- if contains $name .Release.Name }} 19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 20 | {{- else }} 21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 22 | {{- end }} 23 | {{- end }} 24 | {{- end }} 25 | 26 | {{/* 27 | Create chart name and version as used by the chart label. 
28 | */}} 29 | {{- define "sqlite-rest.chart" -}} 30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 31 | {{- end }} 32 | 33 | {{/* 34 | Common labels 35 | */}} 36 | {{- define "sqlite-rest.labels" -}} 37 | helm.sh/chart: {{ include "sqlite-rest.chart" . }} 38 | {{ include "sqlite-rest.selectorLabels" . }} 39 | {{- if .Chart.AppVersion }} 40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} 41 | {{- end }} 42 | app.kubernetes.io/managed-by: {{ .Release.Service }} 43 | {{- end }} 44 | 45 | {{/* 46 | Selector labels 47 | */}} 48 | {{- define "sqlite-rest.selectorLabels" -}} 49 | app.kubernetes.io/name: {{ include "sqlite-rest.name" . }} 50 | app.kubernetes.io/instance: {{ .Release.Name }} 51 | {{- end }} -------------------------------------------------------------------------------- /charts/sqlite-rest/templates/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: {{ include "sqlite-rest.fullname" . }} 5 | labels: 6 | {{- include "sqlite-rest.labels" . | nindent 4 }} 7 | spec: 8 | type: {{ .Values.service.type }} 9 | ports: 10 | - port: {{ .Values.service.port }} 11 | targetPort: http 12 | protocol: TCP 13 | name: http 14 | selector: 15 | {{- include "sqlite-rest.selectorLabels" . | nindent 4 }} 16 | -------------------------------------------------------------------------------- /charts/sqlite-rest/templates/serviceaccount.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ServiceAccount 3 | metadata: 4 | name: {{ .Values.serviceAccount.name }} 5 | labels: 6 | {{- include "sqlite-rest.labels" . | nindent 4 }} -------------------------------------------------------------------------------- /charts/sqlite-rest/templates/statefulset.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: StatefulSet 3 | metadata: 4 | name: {{ include "sqlite-rest.fullname" . }} 5 | labels: 6 | {{- include "sqlite-rest.labels" . | nindent 4 }} 7 | spec: 8 | selector: 9 | matchLabels: 10 | {{- include "sqlite-rest.selectorLabels" . | nindent 6 }} 11 | serviceName: {{ include "sqlite-rest.name" . }} 12 | 13 | replicas: 1 14 | 15 | volumeClaimTemplates: 16 | {{- if .Values.data.enabled }} 17 | - metadata: 18 | name: data 19 | spec: 20 | accessModes: ["ReadWriteOnce"] 21 | {{- if .Values.data.storageClassName }} 22 | storageClassName: {{ .Values.data.storageClassName }} 23 | {{- end }} 24 | resources: 25 | {{- toYaml .Values.data.resource | nindent 10 }} 26 | {{- end }} 27 | 28 | template: 29 | metadata: 30 | labels: 31 | {{- include "sqlite-rest.selectorLabels" . 
| nindent 8 }} 32 | spec: 33 | volumes: 34 | {{- if .Values.server.secretNameAuthToken }} 35 | - name: auth-token-file 36 | secret: 37 | secretName: {{ .Values.server.secretNameAuthToken }} 38 | {{- end }} 39 | {{- if .Values.server.secretNameAuthRSAPublicKey }} 40 | - name: auth-rsa-public-key 41 | secret: 42 | secretName: {{ .Values.server.secretNameAuthRSAPublicKey }} 43 | {{- end }} 44 | {{- if (not .Values.data.enabled) }} 45 | - name: data 46 | emptyDir: {} 47 | {{- end }} 48 | {{- if .Values.migrations.enabled }} 49 | - name: migrations 50 | configMap: 51 | name: {{ .Values.migrations.configMapName }} 52 | {{- end }} 53 | {{- if .Values.litestream.enabled }} 54 | - name: litestream-config 55 | secret: 56 | secretName: {{ .Values.litestream.secretName }} 57 | {{- end }} 58 | initContainers: 59 | {{- if .Values.litestream.enabled }} 60 | - name: litestream-init 61 | image: {{ .Values.litestream.image.repository }}:{{ .Values.litestream.image.tag }} 62 | imagePullPolicy: {{ .Values.litestream.image.pullPolicy }} 63 | args: 64 | - restore 65 | - '-if-db-not-exists' 66 | - '-if-replica-exists' 67 | - '-v' 68 | - '/data/db.sqlite3' 69 | volumeMounts: 70 | - name: data 71 | mountPath: /data 72 | - name: litestream-config 73 | mountPath: /etc/litestream.yml 74 | subPath: litestream.yml 75 | {{- end }} 76 | {{- if .Values.migrations.enabled }} 77 | - name: migrations 78 | image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" 79 | imagePullPolicy: {{ .Values.image.pullPolicy }} 80 | command: 81 | - sqlite-rest 82 | - migrate 83 | args: 84 | - /migrations 85 | - --db-dsn=/data/db.sqlite3 86 | - --log-level={{ .Values.server.logLevel }} 87 | - --log-devel={{ .Values.server.useDevelLog }} 88 | volumeMounts: 89 | - name: data 90 | mountPath: /data 91 | - name: migrations 92 | mountPath: /migrations 93 | {{- end }} 94 | containers: 95 | - name: server 96 | image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" 97 | imagePullPolicy: {{ .Values.image.pullPolicy }} 98 | command: 99 | - sqlite-rest 100 | - serve 101 | args: 102 | - --db-dsn=/data/db.sqlite3 103 | - --security-allow-table={{ .Values.server.securityAllowTable }} 104 | - --log-level={{ .Values.server.logLevel }} 105 | - --log-devel={{ .Values.server.useDevelLog }} 106 | - --metrics-server=:8081 107 | {{- if .Values.server.secretNameAuthToken }} 108 | - --auth-token-file=/auth-token 109 | {{- end }} 110 | {{- if .Values.server.secretNameAuthRSAPublicKey }} 111 | - --auth-rsa-public-key=/auth-rsa-public-key 112 | {{- end }} 113 | ports: 114 | - name: http 115 | containerPort: 8080 116 | protocol: TCP 117 | - name: metrics 118 | containerPort: 8081 119 | protocol: TCP 120 | resources: 121 | {{- toYaml .Values.resources | nindent 12 }} 122 | volumeMounts: 123 | - name: data 124 | mountPath: /data 125 | {{- if .Values.server.secretNameAuthToken }} 126 | - name: auth-token-file 127 | mountPath: /auth-token 128 | subPath: auth.yaml 129 | readOnly: true 130 | {{- end }} 131 | {{- if .Values.server.secretNameAuthRSAPublicKey }} 132 | - name: auth-rsa-public-key 133 | mountPath: /auth-rsa-public-key 134 | subPath: auth.yaml 135 | readOnly: true 136 | {{- end }} 137 | {{- if .Values.litestream.enabled }} 138 | - name: litestream 139 | image: {{ .Values.litestream.image.repository }}:{{ .Values.litestream.image.tag }} 140 | imagePullPolicy: {{ .Values.litestream.image.pullPolicy }} 141 | args: 142 | - replicate 143 | volumeMounts: 144 | - name: data 145 | mountPath: /data 146 
| - name: litestream-config 147 | mountPath: /etc/litestream.yml 148 | subPath: litestream.yml 149 | ports: 150 | - name: ls-metrics 151 | containerPort: 9090 152 | {{- end }} 153 | {{- with .Values.nodeSelector }} 154 | nodeSelector: 155 | {{- toYaml . | nindent 8 }} 156 | {{- end }} -------------------------------------------------------------------------------- /charts/sqlite-rest/values.yaml: -------------------------------------------------------------------------------- 1 | server: 2 | logLevel: 5 3 | useDevelLog: true 4 | secretNameAuthToken: "" 5 | secretNameAuthRSAPublicKey: "" 6 | securityAllowTable: "" 7 | 8 | migrations: 9 | enabled: false 10 | configMapName: sqlite-rest-migrations 11 | 12 | litestream: 13 | enabled: false 14 | image: 15 | repository: litestream/litestream 16 | pullPolicy: IfNotPresent 17 | tag: "0.3.6" 18 | secretName: sqlite-rest-litestream-config 19 | 20 | data: 21 | enabled: false 22 | storageClassName: "" 23 | resource: 24 | requests: 25 | storage: "100Mi" 26 | 27 | labels: 28 | build4.fun/app: sqlite-rest 29 | 30 | image: 31 | repository: ghcr.io/b4fun/sqlite-rest/server 32 | pullPolicy: IfNotPresent 33 | tag: "main" 34 | 35 | service: 36 | type: ClusterIP 37 | port: 8080 38 | 39 | serviceAccount: 40 | name: sqlite-rest 41 | 42 | resources: 43 | limits: 44 | memory: 512Mi 45 | requests: 46 | cpu: 100m 47 | memory: 20Mi 48 | 49 | nodeSelector: 50 | kubernetes.io/os: linux 51 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/jmoiron/sqlx" 7 | _ "github.com/mattn/go-sqlite3" 8 | "github.com/spf13/cobra" 9 | ) 10 | 11 | func openDB(cmd *cobra.Command) (*sqlx.DB, error) { 12 | dsn, err := cmd.Flags().GetString(cliFlagDBDSN) 13 | if err != nil { 14 | return nil, fmt.Errorf("read %s: %w", cliFlagDBDSN, err) 15 | } 16 | 17 | db, err := sqlx.Open("sqlite3", dsn) 18 | if err != nil { 19 | return nil, err 20 | } 21 | 22 | return db, nil 23 | } 24 | -------------------------------------------------------------------------------- /docs/assets/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /examples/bookstore/data.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE books ( 2 | id SERIAL PRIMARY KEY, 3 | title VARCHAR(255) NOT NULL, 4 | author VARCHAR(255) NOT NULL, 5 | price NUMERIC(5,2) NOT NULL 6 | ); 7 | 8 | INSERT INTO books (id, title, author, price) VALUES 9 | (1, "Fairy Tale", "Stephen King", 23.54), 10 | (2, "The Bookstore Sisters: A Short Story", "Alice Hoffman", 1.99), 11 | (3, "The Invisible Life of Addie LaRue", "V.E. 
Schwab", 14.99), 12 | (4, "Zodiac Academy 8: Sorrow and Starlight", "Caroline Peckham", 8.99), 13 | (5, "He Who Fights with Monsters 8: A LitRPG Adventure", "Shirtaloon, Travis Deverell", 51.99); -------------------------------------------------------------------------------- /examples/migrations/1_create_books_table.drop.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS books; -------------------------------------------------------------------------------- /examples/migrations/1_create_books_table.up.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS books; 2 | CREATE TABLE books ( 3 | id SERIAL PRIMARY KEY, 4 | title VARCHAR(255) NOT NULL, 5 | author VARCHAR(255) NOT NULL, 6 | price NUMERIC(5,2) NOT NULL 7 | ); -------------------------------------------------------------------------------- /fixture_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "encoding/json" 8 | "encoding/pem" 9 | "fmt" 10 | "io" 11 | "net/http" 12 | "net/http/httptest" 13 | "net/url" 14 | "os" 15 | "path/filepath" 16 | "testing" 17 | 18 | "github.com/go-logr/logr" 19 | "github.com/golang-jwt/jwt" 20 | "github.com/jmoiron/sqlx" 21 | "github.com/stretchr/testify/assert" 22 | "github.com/supabase/postgrest-go" 23 | "k8s.io/klog/v2/ktesting" 24 | ) 25 | 26 | var enabledTestTables = []string{"test", "test_view"} 27 | 28 | type TestContext struct { 29 | server *httptest.Server 30 | db *sqlx.DB 31 | cleanUpDB func(t testing.TB) 32 | authToken string 33 | } 34 | 35 | func NewTestContextWithDB( 36 | t testing.TB, 37 | handler http.Handler, 38 | db *sqlx.DB, 39 | cleanUpDB func(t testing.TB), 40 | authToken string, 41 | ) *TestContext { 42 | rv := &TestContext{ 43 | server: httptest.NewServer(handler), 44 | db: db, 45 | cleanUpDB: cleanUpDB, 46 | authToken: authToken, 47 | } 48 | 49 | return rv 50 | } 51 | 52 | func (tc *TestContext) CleanUp(t testing.TB) { 53 | if tc.cleanUpDB != nil { 54 | tc.cleanUpDB(t) 55 | } 56 | 57 | tc.server.Close() 58 | } 59 | 60 | func (tc *TestContext) DB() *sqlx.DB { 61 | return tc.db 62 | } 63 | 64 | func (tc *TestContext) ServerURL() *url.URL { 65 | u, err := url.Parse(tc.server.URL) 66 | if err != nil { 67 | // shouldn't happen 68 | panic(fmt.Sprintf("failed to parse server url: %s", err)) 69 | } 70 | return u 71 | } 72 | 73 | func (tc *TestContext) Client() *postgrest.Client { 74 | rv := postgrest.NewClient( 75 | tc.ServerURL().String(), 76 | "http", 77 | nil, 78 | ) 79 | 80 | if tc.authToken != "" { 81 | rv = rv.TokenAuth(tc.authToken) 82 | } 83 | 84 | return rv 85 | } 86 | 87 | func (tc *TestContext) HTTPClient() *http.Client { 88 | return &http.Client{} 89 | } 90 | 91 | func (tc *TestContext) NewRequest( 92 | t testing.TB, 93 | method string, path string, 94 | body io.Reader, 95 | ) *http.Request { 96 | req, err := http.NewRequest(method, tc.ServerURL().String()+"/"+path, body) 97 | assert.NoError(t, err) 98 | 99 | if tc.authToken != "" { 100 | req.Header.Set("Authorization", "Bearer "+tc.authToken) 101 | } 102 | return req 103 | } 104 | 105 | func (tc *TestContext) ExecuteRequest(t testing.TB, req *http.Request) *http.Response { 106 | resp, err := tc.HTTPClient().Do(req) 107 | assert.NoError(t, err) 108 | return resp 109 | } 110 | 111 | func (tc *TestContext) ExecuteSQL(t testing.TB, stmt string, args ...interface{}) { 112 | _, err := 
tc.DB().Exec(stmt, args...) 113 | assert.NoError(t, err) 114 | } 115 | 116 | func (tc *TestContext) DecodeResult(t testing.TB, res []byte, des interface{}) { 117 | err := json.Unmarshal(res, des) 118 | assert.NoError(t, err) 119 | } 120 | 121 | func createTestLogger(t testing.TB) logr.Logger { 122 | return ktesting.NewLogger(t, ktesting.NewConfig(ktesting.Verbosity(12))) 123 | } 124 | 125 | func createTestContextUsingInMemoryDB(t testing.TB) *TestContext { 126 | t.Log("creating in-memory db") 127 | db, err := sqlx.Open("sqlite3", ":memory:") 128 | if err != nil { 129 | t.Fatal(err) 130 | return nil 131 | } 132 | 133 | t.Log("creating server") 134 | serverOpts := &ServerOptions{ 135 | Logger: createTestLogger(t).WithName("test"), 136 | Queryer: db, 137 | Execer: db, 138 | } 139 | serverOpts.AuthOptions.disableAuth = true 140 | serverOpts.SecurityOptions.EnabledTableOrViews = enabledTestTables 141 | server, err := NewServer(serverOpts) 142 | if err != nil { 143 | t.Fatal(err) 144 | return nil 145 | } 146 | 147 | return NewTestContextWithDB( 148 | t, 149 | server.server.Handler, 150 | db, 151 | func(t testing.TB) { 152 | if err := db.Close(); err != nil { 153 | t.Errorf("closing in-memory db: %s", err) 154 | } 155 | }, 156 | "", 157 | ) 158 | } 159 | 160 | func createTestContextWithHMACTokenAuth(t testing.TB) *TestContext { 161 | t.Log("creating test dir") 162 | dir, err := os.MkdirTemp("", "sqlite-rest-test") 163 | if err != nil { 164 | t.Fatal(err) 165 | return nil 166 | } 167 | 168 | t.Log("creating test token file") 169 | testToken := []byte("test-token") 170 | testTokenFile := filepath.Join(dir, "token") 171 | if err := os.WriteFile(testTokenFile, testToken, 0644); err != nil { 172 | t.Fatal(err) 173 | return nil 174 | } 175 | 176 | authToken := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwt.StandardClaims{}) 177 | authTokenString, err := authToken.SignedString(testToken) 178 | if err != nil { 179 | t.Fatal(err) 180 | return nil 181 | } 182 | 183 | db, err := sqlx.Open("sqlite3", "//"+filepath.Join(dir, "test.db")) 184 | if err != nil { 185 | t.Fatal(err) 186 | return nil 187 | } 188 | 189 | t.Log("creating server") 190 | serverOpts := &ServerOptions{ 191 | Logger: createTestLogger(t).WithName("test"), 192 | Queryer: db, 193 | Execer: db, 194 | } 195 | serverOpts.AuthOptions.TokenFilePath = testTokenFile 196 | serverOpts.SecurityOptions.EnabledTableOrViews = enabledTestTables 197 | server, err := NewServer(serverOpts) 198 | if err != nil { 199 | t.Fatal(err) 200 | return nil 201 | } 202 | 203 | return NewTestContextWithDB( 204 | t, 205 | server.server.Handler, 206 | db, 207 | func(t testing.TB) { 208 | if err := db.Close(); err != nil { 209 | t.Fatalf("closing db: %s", err) 210 | return 211 | } 212 | 213 | if err := os.RemoveAll(dir); err != nil { 214 | t.Fatalf("removing test dir %q: %s", dir, err) 215 | return 216 | } 217 | }, 218 | authTokenString, 219 | ) 220 | } 221 | 222 | func createTestContextWithRSATokenAuth(t testing.TB) *TestContext { 223 | t.Log("creating test dir") 224 | dir, err := os.MkdirTemp("", "sqlite-rest-test") 225 | if err != nil { 226 | t.Fatal(err) 227 | return nil 228 | } 229 | 230 | t.Log("creating test token file") 231 | privateKey, err := rsa.GenerateKey(rand.Reader, 1024) 232 | if err != nil { 233 | t.Fatal(err) 234 | return nil 235 | } 236 | b, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) 237 | if err != nil { 238 | t.Fatal(err) 239 | return nil 240 | } 241 | publicKeyPem := pem.EncodeToMemory(&pem.Block{ 242 | Type: "PUBLIC KEY", 243 | Bytes: b, 
244 | }) 245 | 246 | testTokenFile := filepath.Join(dir, "token") 247 | if err := os.WriteFile(testTokenFile, publicKeyPem, 0644); err != nil { 248 | t.Fatal(err) 249 | return nil 250 | } 251 | 252 | authToken := jwt.NewWithClaims(jwt.SigningMethodRS256, &jwt.StandardClaims{}) 253 | authTokenString, err := authToken.SignedString(privateKey) 254 | if err != nil { 255 | t.Fatal(err) 256 | return nil 257 | } 258 | 259 | db, err := sqlx.Open("sqlite3", "//"+filepath.Join(dir, "test.db")) 260 | if err != nil { 261 | t.Fatal(err) 262 | return nil 263 | } 264 | 265 | t.Log("creating server") 266 | serverOpts := &ServerOptions{ 267 | Logger: createTestLogger(t).WithName("test"), 268 | Queryer: db, 269 | Execer: db, 270 | } 271 | serverOpts.AuthOptions.RSAPublicKeyFilePath = testTokenFile 272 | serverOpts.SecurityOptions.EnabledTableOrViews = enabledTestTables 273 | server, err := NewServer(serverOpts) 274 | if err != nil { 275 | t.Fatal(err) 276 | return nil 277 | } 278 | 279 | return NewTestContextWithDB( 280 | t, 281 | server.server.Handler, 282 | db, 283 | func(t testing.TB) { 284 | if err := db.Close(); err != nil { 285 | t.Fatalf("closing db: %s", err) 286 | return 287 | } 288 | 289 | if err := os.RemoveAll(dir); err != nil { 290 | t.Fatalf("removing test dir %q: %s", dir, err) 291 | return 292 | } 293 | }, 294 | authTokenString, 295 | ) 296 | } 297 | 298 | type MigrationTestContext struct { 299 | migrator *dbMigrator 300 | db *sqlx.DB 301 | cleanUpDB func(t testing.TB) 302 | } 303 | 304 | func (mtc *MigrationTestContext) Migrator() *dbMigrator { 305 | return mtc.migrator 306 | } 307 | 308 | func (mtc *MigrationTestContext) CleanUp(t testing.TB) { 309 | if mtc.cleanUpDB != nil { 310 | mtc.cleanUpDB(t) 311 | } 312 | } 313 | 314 | func NewMigrationTestContext( 315 | t testing.TB, 316 | migrations map[string]string, 317 | ) *MigrationTestContext { 318 | t.Log("creating test dir") 319 | dir, err := os.MkdirTemp("", "sqlite-rest-test") 320 | if err != nil { 321 | t.Fatal(err) 322 | return nil 323 | } 324 | 325 | migrationsDir := filepath.Join(dir, "migrations") 326 | if err := os.MkdirAll(migrationsDir, 0755); err != nil { 327 | t.Fatal(err) 328 | return nil 329 | } 330 | 331 | t.Log("writing migrations") 332 | for filename, content := range migrations { 333 | p := filepath.Join(migrationsDir, filename) 334 | if err := os.WriteFile(p, []byte(content), 0644); err != nil { 335 | t.Fatal(err) 336 | return nil 337 | } 338 | } 339 | 340 | t.Log("craeting in-memory db") 341 | db, err := sqlx.Open("sqlite3", "") 342 | if err != nil { 343 | t.Fatal(err) 344 | return nil 345 | } 346 | 347 | t.Log("creating migrator") 348 | migratorOpts := &MigrateOptions{ 349 | Logger: createTestLogger(t).WithName("test"), 350 | DB: db.DB, 351 | SourceDIR: migrationsDir, 352 | } 353 | 354 | migrator, err := NewMigrator(migratorOpts) 355 | if err != nil { 356 | t.Fatal(err) 357 | return nil 358 | } 359 | 360 | return &MigrationTestContext{ 361 | migrator: migrator, 362 | db: db, 363 | cleanUpDB: func(t testing.TB) { 364 | if err := db.Close(); err != nil { 365 | t.Errorf("closing in-memory db: %s", err) 366 | return 367 | } 368 | 369 | if err := os.RemoveAll(dir); err != nil { 370 | t.Fatalf("removing test dir %q: %s", dir, err) 371 | return 372 | } 373 | }, 374 | } 375 | } 376 | -------------------------------------------------------------------------------- /fs.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "sync" 6 | ) 7 | 8 | func 
readFileWithStatCache(file string) func() ([]byte, error) { 9 | mu := new(sync.RWMutex) 10 | var ( 11 | lastReadContent []byte 12 | lastStat os.FileInfo 13 | ) 14 | 15 | fast := func() (bool, []byte, error) { 16 | stat, err := os.Stat(file) 17 | if err != nil { 18 | return false, nil, err 19 | } 20 | 21 | mu.RLock() 22 | defer mu.RUnlock() 23 | 24 | if lastStat == nil { 25 | // no cache 26 | return false, nil, nil 27 | } 28 | 29 | if lastStat.ModTime() == stat.ModTime() { 30 | return true, lastReadContent, nil 31 | } 32 | 33 | // mod time changed 34 | return false, nil, nil 35 | } 36 | 37 | slow := func() ([]byte, error) { 38 | mu.Lock() 39 | defer mu.Unlock() 40 | 41 | stat, err := os.Stat(file) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | if lastStat != nil && lastStat.ModTime() == stat.ModTime() { 47 | return lastReadContent, nil 48 | } 49 | 50 | lastStat = stat 51 | lastReadContent = nil 52 | 53 | content, err := os.ReadFile(file) 54 | if err != nil { 55 | return nil, err 56 | } 57 | lastReadContent = content 58 | 59 | return content, nil 60 | } 61 | 62 | return func() ([]byte, error) { 63 | readFromCache, content, err := fast() 64 | if err != nil { 65 | return nil, err 66 | } 67 | if readFromCache { 68 | return content, nil 69 | } 70 | 71 | return slow() 72 | } 73 | } -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/b4fun/sqlite-rest 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/go-chi/chi/v5 v5.2.1 7 | github.com/go-chi/cors v1.2.1 8 | github.com/go-logr/logr v1.4.2 9 | github.com/go-logr/zapr v1.3.0 10 | github.com/golang-jwt/jwt v3.2.2+incompatible 11 | github.com/golang-migrate/migrate/v4 v4.17.1 12 | github.com/jmoiron/sqlx v1.4.0 13 | github.com/mattn/go-sqlite3 v1.14.24 14 | github.com/prometheus/client_golang v1.20.5 15 | github.com/spf13/cobra v1.9.1 16 | github.com/spf13/pflag v1.0.6 17 | github.com/stretchr/testify v1.10.0 18 | github.com/supabase/postgrest-go v0.0.7 19 | go.uber.org/zap v1.27.0 20 | k8s.io/klog/v2 v2.130.1 21 | ) 22 | 23 | require ( 24 | github.com/beorn7/perks v1.0.1 // indirect 25 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 26 | github.com/davecgh/go-spew v1.1.1 // indirect 27 | github.com/hashicorp/errwrap v1.1.0 // indirect 28 | github.com/hashicorp/go-multierror v1.1.1 // indirect 29 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 30 | github.com/klauspost/compress v1.17.9 // indirect 31 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 32 | github.com/pmezard/go-difflib v1.0.0 // indirect 33 | github.com/prometheus/client_model v0.6.1 // indirect 34 | github.com/prometheus/common v0.55.0 // indirect 35 | github.com/prometheus/procfs v0.15.1 // indirect 36 | github.com/rogpeppe/go-internal v1.11.0 // indirect 37 | go.uber.org/atomic v1.7.0 // indirect 38 | go.uber.org/multierr v1.10.0 // indirect 39 | golang.org/x/sys v0.22.0 // indirect 40 | google.golang.org/protobuf v1.34.2 // indirect 41 | gopkg.in/yaml.v3 v3.0.1 // indirect 42 | ) 43 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= 2 | filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= 3 | github.com/beorn7/perks v1.0.1 
h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= 4 | github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 5 | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 6 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 7 | github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= 8 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 10 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 11 | github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= 12 | github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= 13 | github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= 14 | github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= 15 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= 16 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 17 | github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= 18 | github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg= 19 | github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= 20 | github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= 21 | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= 22 | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= 23 | github.com/golang-migrate/migrate/v4 v4.17.1 h1:4zQ6iqL6t6AiItphxJctQb3cFqWiSpMnX7wLTPnnYO4= 24 | github.com/golang-migrate/migrate/v4 v4.17.1/go.mod h1:m8hinFyWBn0SA4QKHuKh175Pm9wjmxj3S2Mia7dbXzM= 25 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 26 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 27 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 28 | github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= 29 | github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 30 | github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= 31 | github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= 32 | github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= 33 | github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= 34 | github.com/jarcoal/httpmock v1.1.0 h1:F47ChZj1Y2zFsCXxNkBPwNNKnAyOATcdQibk0qEdVCE= 35 | github.com/jarcoal/httpmock v1.1.0/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik= 36 | github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= 37 | github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= 38 | github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= 39 | github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= 40 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 41 | github.com/kr/pretty v0.3.1/go.mod 
h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 42 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 43 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 44 | github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= 45 | github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= 46 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 47 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 48 | github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= 49 | github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= 50 | github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= 51 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= 52 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= 53 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 54 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 55 | github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= 56 | github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= 57 | github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= 58 | github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= 59 | github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= 60 | github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= 61 | github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= 62 | github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= 63 | github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= 64 | github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= 65 | github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 66 | github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= 67 | github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= 68 | github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= 69 | github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= 70 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 71 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 72 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 73 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 74 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 75 | github.com/supabase/postgrest-go v0.0.7 h1:wkOzrndF/KliPEVHM84lNnET7ZFjAk1OPpAxz8hgzRs= 76 | github.com/supabase/postgrest-go v0.0.7/go.mod h1:sqnMeRGv0p8BzJX7busTdpT51tRdJHX9R5kd8oziovo= 77 | go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= 78 | go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 79 | go.uber.org/goleak v1.3.0 
h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 80 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 81 | go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= 82 | go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= 83 | go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= 84 | go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= 85 | golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= 86 | golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 87 | google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= 88 | google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= 89 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 90 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 91 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 92 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 93 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 94 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 95 | k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= 96 | k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= 97 | -------------------------------------------------------------------------------- /integration_delete_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func testDelete_SingleTable(t *testing.T, createTestContext func(t testing.TB) *TestContext) { 10 | t.Run("NoTable", func(t *testing.T) { 11 | t.Parallel() 12 | tc := createTestContext(t) 13 | defer tc.CleanUp(t) 14 | 15 | client := tc.Client() 16 | _, _, err := client.From("test").Delete("", "").Execute() 17 | assert.Error(t, err) 18 | assert.Contains(t, err.Error(), "no such table: test") 19 | }) 20 | 21 | t.Run("DeleteFromEmptyTable", func(t *testing.T) { 22 | t.Parallel() 23 | tc := createTestContext(t) 24 | defer tc.CleanUp(t) 25 | 26 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 27 | 28 | client := tc.Client() 29 | _, _, err := client.From("test").Delete("", "").Execute() 30 | assert.NoError(t, err) 31 | }) 32 | 33 | t.Run("DeleteFromNonEmptyTable", func(t *testing.T) { 34 | t.Parallel() 35 | tc := createTestContext(t) 36 | defer tc.CleanUp(t) 37 | 38 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 39 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) VALUES (1, "a"), (1, "a"), (1, "a")`) 40 | 41 | client := tc.Client() 42 | _, _, err := client.From("test").Delete("", "").Execute() 43 | assert.NoError(t, err) 44 | 45 | res, _, err := client.From("test").Select("id", "", false). 
46 | Execute() 47 | assert.NoError(t, err) 48 | 49 | var rv []map[string]interface{} 50 | tc.DecodeResult(t, res, &rv) 51 | assert.Empty(t, rv) 52 | }) 53 | 54 | t.Run("DeleteWithFilter", func(t *testing.T) { 55 | t.Parallel() 56 | tc := createTestContext(t) 57 | defer tc.CleanUp(t) 58 | 59 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 60 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) VALUES (1, "a"), (2, "a"), (3, "a")`) 61 | 62 | client := tc.Client() 63 | _, _, err := client.From("test").Delete("", ""). 64 | Gt("id", "1"). 65 | Execute() 66 | assert.NoError(t, err) 67 | 68 | res, _, err := client.From("test").Select("id", "", false). 69 | Execute() 70 | assert.NoError(t, err) 71 | 72 | var rv []map[string]interface{} 73 | tc.DecodeResult(t, res, &rv) 74 | assert.Len(t, rv, 1) 75 | assert.EqualValues(t, 1, rv[0]["id"]) 76 | }) 77 | } 78 | 79 | func TestDelete_SingleTable(t *testing.T) { 80 | t.Run("in memory db", func(t *testing.T) { 81 | testDelete_SingleTable(t, createTestContextUsingInMemoryDB) 82 | }) 83 | 84 | t.Run("HMAC token auth", func(t *testing.T) { 85 | testDelete_SingleTable(t, createTestContextWithHMACTokenAuth) 86 | }) 87 | 88 | t.Run("RSA token auth", func(t *testing.T) { 89 | testDelete_SingleTable(t, createTestContextWithRSATokenAuth) 90 | }) 91 | } 92 | -------------------------------------------------------------------------------- /integration_insert_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func testInsert_SingleTable(t *testing.T, createTestContext func(t testing.TB) *TestContext) { 12 | t.Run("NoTable", func(t *testing.T) { 13 | t.Parallel() 14 | tc := createTestContext(t) 15 | defer tc.CleanUp(t) 16 | 17 | client := tc.Client() 18 | _, _, err := client.From("test"). 19 | Insert(map[string]interface{}{"id": 1}, false, "", "", ""). 20 | Execute() 21 | assert.Error(t, err) 22 | assert.Contains(t, err.Error(), "no such table: test") 23 | }) 24 | 25 | t.Run("InsertSingleValue", func(t *testing.T) { 26 | t.Parallel() 27 | tc := createTestContext(t) 28 | defer tc.CleanUp(t) 29 | 30 | tc.ExecuteSQL(t, "CREATE TABLE test (id int)") 31 | 32 | client := tc.Client() 33 | 34 | _, _, err := client.From("test"). 35 | Insert(map[string]interface{}{"id": 1}, false, "", "", ""). 36 | Execute() 37 | assert.NoError(t, err) 38 | 39 | res, _, err := client.From("test").Select("id", "", false). 40 | Execute() 41 | assert.NoError(t, err) 42 | 43 | var rv []map[string]interface{} 44 | tc.DecodeResult(t, res, &rv) 45 | assert.Len(t, rv, 1) 46 | assert.EqualValues(t, 1, rv[0]["id"]) 47 | }) 48 | 49 | t.Run("InsertValues", func(t *testing.T) { 50 | t.Parallel() 51 | tc := createTestContext(t) 52 | defer tc.CleanUp(t) 53 | 54 | tc.ExecuteSQL(t, "CREATE TABLE test (id int)") 55 | 56 | client := tc.Client() 57 | 58 | _, _, err := client.From("test"). 59 | Insert([]map[string]interface{}{{"id": 1}, {"id": 1}}, false, "", "", ""). 60 | Execute() 61 | assert.NoError(t, err) 62 | 63 | res, _, err := client.From("test").Select("id", "", false). 
64 | Execute() 65 | assert.NoError(t, err) 66 | 67 | var rv []map[string]interface{} 68 | tc.DecodeResult(t, res, &rv) 69 | assert.Len(t, rv, 2) 70 | for _, row := range rv { 71 | assert.EqualValues(t, 1, row["id"]) 72 | } 73 | }) 74 | 75 | t.Run("UpsertMergeDuplicates", func(t *testing.T) { 76 | t.Parallel() 77 | tc := createTestContext(t) 78 | defer tc.CleanUp(t) 79 | 80 | tc.ExecuteSQL(t, "CREATE TABLE test (id int primary key, s text)") 81 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) values (1, "a"), (2, "b")`) 82 | 83 | client := tc.Client() 84 | 85 | _, _, err := client.From("test"). 86 | Insert([]map[string]interface{}{ 87 | {"id": 1, "s": "b"}, {"id": 2, "s": "c"}, 88 | }, true, "", "", ""). 89 | Execute() 90 | assert.NoError(t, err) 91 | 92 | res, _, err := client.From("test").Select("*", "", false). 93 | Execute() 94 | assert.NoError(t, err) 95 | 96 | var rv []map[string]interface{} 97 | tc.DecodeResult(t, res, &rv) 98 | assert.Len(t, rv, 2) 99 | for idx, row := range rv { 100 | assert.EqualValues(t, idx+1, row["id"]) 101 | assert.EqualValues(t, string('a'+rune(idx+1)), row["s"]) 102 | } 103 | }) 104 | 105 | t.Run("UpsertIgnoreDuplicates", func(t *testing.T) { 106 | t.Parallel() 107 | tc := createTestContext(t) 108 | defer tc.CleanUp(t) 109 | 110 | tc.ExecuteSQL(t, "CREATE TABLE test (id int primary key, s text)") 111 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) values (1, "a"), (2, "b")`) 112 | 113 | payload := bytes.NewBufferString(`[{"id": 1, "s": "b"}, {"id": 2, "s": "c"}]`) 114 | req := tc.NewRequest(t, http.MethodPost, "test", payload) 115 | req.Header.Set("Content-Type", "application/json") 116 | req.Header.Set("Prefer", "resolution=ignore-duplicates") 117 | resp := tc.ExecuteRequest(t, req) 118 | defer resp.Body.Close() 119 | 120 | client := tc.Client() 121 | res, _, err := client.From("test").Select("*", "", false). 122 | Execute() 123 | assert.NoError(t, err) 124 | 125 | var rv []map[string]interface{} 126 | tc.DecodeResult(t, res, &rv) 127 | assert.Len(t, rv, 2) 128 | for idx, row := range rv { 129 | assert.EqualValues(t, idx+1, row["id"]) 130 | assert.EqualValues(t, string('a'+rune(idx)), row["s"]) 131 | } 132 | }) 133 | 134 | t.Run("UpsertMergeDuplicatesWithOnConflicts", func(t *testing.T) { 135 | t.Parallel() 136 | tc := createTestContext(t) 137 | defer tc.CleanUp(t) 138 | 139 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 140 | tc.ExecuteSQL(t, "CREATE UNIQUE INDEX test_id on test (id)") 141 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) values (1, "a"), (2, "b")`) 142 | 143 | client := tc.Client() 144 | 145 | _, _, err := client.From("test"). 146 | Insert([]map[string]interface{}{ 147 | {"id": 1, "s": "b"}, {"id": 2, "s": "c"}, 148 | }, true, "id", "", ""). 149 | Execute() 150 | assert.NoError(t, err) 151 | 152 | res, _, err := client.From("test").Select("*", "", false). 
153 | Execute() 154 | assert.NoError(t, err) 155 | 156 | var rv []map[string]interface{} 157 | tc.DecodeResult(t, res, &rv) 158 | assert.Len(t, rv, 2) 159 | for idx, row := range rv { 160 | assert.EqualValues(t, idx+1, row["id"]) 161 | assert.EqualValues(t, string('a'+rune(idx+1)), row["s"]) 162 | } 163 | }) 164 | } 165 | 166 | func TestInsert_SingleTable(t *testing.T) { 167 | t.Run("in memory db", func(t *testing.T) { 168 | testInsert_SingleTable(t, createTestContextUsingInMemoryDB) 169 | }) 170 | 171 | t.Run("HMAC token auth", func(t *testing.T) { 172 | testInsert_SingleTable(t, createTestContextWithHMACTokenAuth) 173 | }) 174 | 175 | t.Run("RSA token auth", func(t *testing.T) { 176 | testInsert_SingleTable(t, createTestContextWithRSATokenAuth) 177 | }) 178 | } 179 | -------------------------------------------------------------------------------- /integration_migrate_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestMigration(t *testing.T) { 12 | t.Run("empty migrations", func(t *testing.T) { 13 | t.Parallel() 14 | tc := NewMigrationTestContext(t, nil) 15 | defer tc.CleanUp(t) 16 | 17 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 18 | defer cancel() 19 | 20 | err := tc.Migrator().Up(ctx, migrationStepAll) 21 | assert.NoError(t, err) 22 | 23 | err = tc.Migrator().Down(ctx, migrationStepAll) 24 | assert.NoError(t, err) 25 | }) 26 | 27 | t.Run("apply all migrations", func(t *testing.T) { 28 | t.Parallel() 29 | tc := NewMigrationTestContext(t, map[string]string{ 30 | "1_test.up.sql": `create table test (id int);`, 31 | "1_test.down.sql": `drop table test;`, 32 | }) 33 | defer tc.CleanUp(t) 34 | 35 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 36 | defer cancel() 37 | 38 | t.Log("up") 39 | { 40 | err := tc.Migrator().Up(ctx, migrationStepAll) 41 | assert.NoError(t, err) 42 | 43 | t.Log("rerunning migrations") 44 | err = tc.Migrator().Up(ctx, migrationStepAll) 45 | assert.NoError(t, err) 46 | } 47 | 48 | t.Log("down") 49 | { 50 | err := tc.Migrator().Down(ctx, migrationStepAll) 51 | assert.NoError(t, err) 52 | 53 | t.Log("rerunning migrations") 54 | err = tc.Migrator().Down(ctx, migrationStepAll) 55 | assert.NoError(t, err) 56 | } 57 | }) 58 | 59 | t.Run("apply migrations by step", func(t *testing.T) { 60 | t.Parallel() 61 | tc := NewMigrationTestContext(t, map[string]string{ 62 | "1_test.up.sql": `create table test (id int);`, 63 | "1_test.down.sql": `drop table test;`, 64 | "2_test2.up.sql": `create table test2 (id int);`, 65 | "2_test2.down.sql": `drop table test2;`, 66 | }) 67 | defer tc.CleanUp(t) 68 | 69 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 70 | defer cancel() 71 | 72 | t.Log("up 1 step (current step = 0)") 73 | { 74 | err := tc.Migrator().Up(ctx, 1) 75 | assert.NoError(t, err) 76 | } 77 | t.Log("up 1 step (current step = 1)") 78 | { 79 | err := tc.Migrator().Up(ctx, 1) 80 | assert.NoError(t, err) 81 | } 82 | t.Log("up 1 step (current step = 2)") 83 | { 84 | err := tc.Migrator().Up(ctx, 1) 85 | assert.Error(t, err) 86 | } 87 | 88 | t.Log("down 1 step (current step = 2)") 89 | { 90 | err := tc.Migrator().Down(ctx, 1) 91 | assert.NoError(t, err) 92 | } 93 | t.Log("down 1 step (current step = 1)") 94 | { 95 | err := tc.Migrator().Down(ctx, 1) 96 | assert.NoError(t, err) 97 | } 98 | t.Log("down 1 step (current step = 
0)") 99 | { 100 | err := tc.Migrator().Down(ctx, 1) 101 | assert.Error(t, err) 102 | } 103 | }) 104 | 105 | t.Run("failed migrations", func(t *testing.T) { 106 | t.Run("up", func(t *testing.T) { 107 | t.Parallel() 108 | tc := NewMigrationTestContext(t, map[string]string{ 109 | "1_test.up.sql": `create table test invalid sql;`, 110 | "1_test.down.sql": `drop table test;`, 111 | }) 112 | defer tc.CleanUp(t) 113 | 114 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 115 | defer cancel() 116 | 117 | err := tc.Migrator().Up(ctx, migrationStepAll) 118 | assert.Error(t, err) 119 | }) 120 | 121 | t.Run("down", func(t *testing.T) { 122 | t.Parallel() 123 | tc := NewMigrationTestContext(t, map[string]string{ 124 | "1_test.up.sql": `create table test (id int);`, 125 | "1_test.down.sql": `drop table invalid sql;`, 126 | }) 127 | defer tc.CleanUp(t) 128 | 129 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 130 | defer cancel() 131 | 132 | err := tc.Migrator().Up(ctx, migrationStepAll) 133 | assert.NoError(t, err) 134 | 135 | err = tc.migrator.Down(ctx, migrationStepAll) 136 | assert.Error(t, err) 137 | }) 138 | }) 139 | } 140 | -------------------------------------------------------------------------------- /integration_security_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestSecurityNegativeCases(t *testing.T) { 13 | t.Run("Unauthorized", func(t *testing.T) { 14 | t.Parallel() 15 | tc := createTestContextWithHMACTokenAuth(t) 16 | defer tc.CleanUp(t) 17 | 18 | tc.authToken = "" // disable auth 19 | client := tc.Client() 20 | _, _, err := client.From("test").Select("id", "", false).Execute() 21 | assert.Error(t, err) 22 | assert.Contains(t, err.Error(), "Unauthorized") 23 | }) 24 | 25 | t.Run("TableAccessRestricted", func(t *testing.T) { 26 | t.Parallel() 27 | tc := createTestContextWithHMACTokenAuth(t) 28 | defer tc.CleanUp(t) 29 | 30 | client := tc.Client() 31 | _, _, err := client.From(tableNameMigrations).Select("id", "", false).Execute() 32 | assert.Error(t, err) 33 | assert.Contains(t, err.Error(), "Access Restricted") 34 | }) 35 | } 36 | 37 | func TestSecuritySQLInjection(t *testing.T) { 38 | t.Run("Update", func(t *testing.T) { 39 | t.Parallel() 40 | tc := createTestContextWithHMACTokenAuth(t) 41 | defer tc.CleanUp(t) 42 | 43 | tc.ExecuteSQL(t, "CREATE TABLE test (id int)") 44 | tc.ExecuteSQL(t, "insert into test values (1)") 45 | 46 | p := bytes.NewBufferString(`{"id": 2}`) 47 | req := tc.NewRequest(t, http.MethodPost, "test", p) 48 | req.Header.Set("content-type", "application/json") 49 | q := req.URL.Query() 50 | q.Set("select", "1; drop table test;select *") 51 | req.URL.RawQuery = q.Encode() 52 | 53 | resp := tc.ExecuteRequest(t, req) 54 | defer resp.Body.Close() 55 | 56 | assert.Equal(t, http.StatusCreated, resp.StatusCode) 57 | 58 | _, err := io.ReadAll(resp.Body) 59 | assert.NoError(t, err) 60 | 61 | client := tc.Client() 62 | res, _, err := client.From("test").Select("*", "", false).Execute() 63 | assert.NoError(t, err) 64 | 65 | var rv []map[string]interface{} 66 | tc.DecodeResult(t, res, &rv) 67 | assert.Len(t, rv, 2) 68 | }) 69 | 70 | t.Run("Select", func(t *testing.T) { 71 | t.Parallel() 72 | tc := createTestContextWithHMACTokenAuth(t) 73 | defer tc.CleanUp(t) 74 | 75 | tc.ExecuteSQL(t, "CREATE TABLE test (id int)") 76 | tc.ExecuteSQL(t, 
"insert into test values (1)") 77 | 78 | req := tc.NewRequest(t, http.MethodGet, "test", nil) 79 | req.Header.Set("content-type", "application/json") 80 | q := req.URL.Query() 81 | q.Set("select", "1; drop table test;select *") 82 | req.URL.RawQuery = q.Encode() 83 | 84 | resp := tc.ExecuteRequest(t, req) 85 | defer resp.Body.Close() 86 | 87 | assert.Equal(t, http.StatusOK, resp.StatusCode) 88 | 89 | _, err := io.ReadAll(resp.Body) 90 | assert.NoError(t, err) 91 | 92 | client := tc.Client() 93 | res, _, err := client.From("test").Select("*", "", false).Execute() 94 | assert.NoError(t, err) 95 | 96 | var rv []map[string]interface{} 97 | tc.DecodeResult(t, res, &rv) 98 | assert.Len(t, rv, 1) 99 | }) 100 | } 101 | -------------------------------------------------------------------------------- /integration_select_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/supabase/postgrest-go" 12 | ) 13 | 14 | func testSelect_SingleTable(t *testing.T, createTestContext func(t testing.TB) *TestContext) { 15 | t.Helper() 16 | 17 | t.Run("NoTable", func(t *testing.T) { 18 | t.Parallel() 19 | tc := createTestContext(t) 20 | defer tc.CleanUp(t) 21 | 22 | client := tc.Client() 23 | _, _, err := client.From("test").Select("id", "", false). 24 | Execute() 25 | assert.Error(t, err) 26 | assert.Contains(t, err.Error(), "no such table: test") 27 | }) 28 | 29 | t.Run("EmptyTable", func(t *testing.T) { 30 | t.Parallel() 31 | tc := createTestContext(t) 32 | defer tc.CleanUp(t) 33 | 34 | tc.ExecuteSQL(t, "CREATE TABLE test (id int)") 35 | 36 | client := tc.Client() 37 | res, _, err := client.From("test").Select("id", "", false). 38 | Execute() 39 | assert.NoError(t, err) 40 | 41 | var rv []map[string]interface{} 42 | tc.DecodeResult(t, res, &rv) 43 | assert.Empty(t, rv) 44 | }) 45 | 46 | t.Run("SelectAllColumns", func(t *testing.T) { 47 | t.Parallel() 48 | tc := createTestContext(t) 49 | defer tc.CleanUp(t) 50 | 51 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 52 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) VALUES (1, "a"), (1, "a"), (1, "a")`) 53 | 54 | client := tc.Client() 55 | res, _, err := client.From("test").Select("*", "", false). 56 | Execute() 57 | assert.NoError(t, err) 58 | 59 | var rv []map[string]interface{} 60 | tc.DecodeResult(t, res, &rv) 61 | assert.Len(t, rv, 3) 62 | for _, row := range rv { 63 | assert.EqualValues(t, 1, row["id"]) 64 | assert.EqualValues(t, "a", row["s"]) 65 | } 66 | }) 67 | 68 | t.Run("SelectSingleColumn", func(t *testing.T) { 69 | t.Parallel() 70 | tc := createTestContext(t) 71 | defer tc.CleanUp(t) 72 | 73 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 74 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) VALUES (1, "a"), (1, "a"), (1, "a")`) 75 | 76 | client := tc.Client() 77 | res, _, err := client.From("test").Select("id", "", false). 
78 | Execute() 79 | assert.NoError(t, err) 80 | 81 | var rv []map[string]interface{} 82 | tc.DecodeResult(t, res, &rv) 83 | assert.Len(t, rv, 3) 84 | for _, row := range rv { 85 | assert.EqualValues(t, 1, row["id"]) 86 | } 87 | }) 88 | 89 | t.Run("SelectWithFilter", func(t *testing.T) { 90 | t.Parallel() 91 | tc := createTestContext(t) 92 | defer tc.CleanUp(t) 93 | 94 | tc.ExecuteSQL(t, "CREATE TABLE test (id int)") 95 | tc.ExecuteSQL(t, `INSERT INTO test (id) VALUES (1), (2), (3)`) 96 | 97 | client := tc.Client() 98 | res, _, err := client.From("test").Select("id", "", false). 99 | Eq("id", "1"). 100 | Execute() 101 | assert.NoError(t, err) 102 | 103 | var rv []map[string]interface{} 104 | tc.DecodeResult(t, res, &rv) 105 | assert.Len(t, rv, 1) 106 | assert.EqualValues(t, 1, rv[0]["id"]) 107 | }) 108 | 109 | t.Run("SelectWithOrder", func(t *testing.T) { 110 | t.Parallel() 111 | tc := createTestContext(t) 112 | defer tc.CleanUp(t) 113 | 114 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 115 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) VALUES (1, "a"), (2, "b"), (3, "b")`) 116 | 117 | client := tc.Client() 118 | 119 | { 120 | res, _, err := client.From("test").Select("*", "", false). 121 | Order("id", &postgrest.OrderOpts{ 122 | Ascending: true, 123 | }). 124 | Execute() 125 | assert.NoError(t, err) 126 | 127 | var rv []map[string]interface{} 128 | tc.DecodeResult(t, res, &rv) 129 | assert.Len(t, rv, 3) 130 | assert.EqualValues(t, 1, rv[0]["id"]) 131 | assert.EqualValues(t, 2, rv[1]["id"]) 132 | assert.EqualValues(t, 3, rv[2]["id"]) 133 | } 134 | 135 | { 136 | res, _, err := client.From("test").Select("*", "", false). 137 | Order("id", &postgrest.OrderOpts{ 138 | Ascending: false, 139 | }). 140 | Execute() 141 | assert.NoError(t, err) 142 | 143 | var rv []map[string]interface{} 144 | tc.DecodeResult(t, res, &rv) 145 | assert.Len(t, rv, 3) 146 | assert.EqualValues(t, 3, rv[0]["id"]) 147 | assert.EqualValues(t, 2, rv[1]["id"]) 148 | assert.EqualValues(t, 1, rv[2]["id"]) 149 | } 150 | 151 | { 152 | res, _, err := client.From("test").Select("*", "", false). 153 | Order("s", &postgrest.OrderOpts{ 154 | Ascending: true, 155 | }). 156 | Order("id", &postgrest.OrderOpts{ 157 | Ascending: false, 158 | }). 159 | Execute() 160 | assert.NoError(t, err) 161 | 162 | var rv []map[string]interface{} 163 | tc.DecodeResult(t, res, &rv) 164 | assert.Len(t, rv, 3) 165 | assert.EqualValues(t, 1, rv[0]["id"]) 166 | assert.EqualValues(t, 3, rv[1]["id"]) 167 | assert.EqualValues(t, 2, rv[2]["id"]) 168 | } 169 | }) 170 | 171 | t.Run("SelectPagination", func(t *testing.T) { 172 | t.Parallel() 173 | const rowsCount = int64(10) 174 | 175 | tc := createTestContext(t) 176 | defer tc.CleanUp(t) 177 | 178 | tc.ExecuteSQL(t, "CREATE TABLE test (id int)") 179 | var ps []string 180 | for i := int64(0); i < rowsCount; i++ { 181 | ps = append(ps, fmt.Sprintf("(%d)", i+1)) 182 | } 183 | tc.ExecuteSQL(t, fmt.Sprintf(`INSERT INTO test (id) VALUES %s`, strings.Join(ps, ", "))) 184 | 185 | client := tc.Client() 186 | 187 | { 188 | res, _, err := client.From("test").Select("*", "", false). 189 | Limit(3, ""). 190 | Order("id", &postgrest.OrderOpts{Ascending: true}). 
191 | Execute() 192 | assert.NoError(t, err) 193 | 194 | var rv []map[string]interface{} 195 | tc.DecodeResult(t, res, &rv) 196 | assert.Len(t, rv, 3) 197 | assert.EqualValues(t, 1, rv[0]["id"]) 198 | assert.EqualValues(t, 2, rv[1]["id"]) 199 | assert.EqualValues(t, 3, rv[2]["id"]) 200 | } 201 | 202 | { 203 | res, _, err := client.From("test").Select("*", "", false). 204 | Range(3, 5, ""). 205 | Order("id", &postgrest.OrderOpts{Ascending: true}). 206 | Execute() 207 | assert.NoError(t, err) 208 | 209 | var rv []map[string]interface{} 210 | tc.DecodeResult(t, res, &rv) 211 | assert.Len(t, rv, 3) 212 | assert.EqualValues(t, 4, rv[0]["id"]) 213 | assert.EqualValues(t, 5, rv[1]["id"]) 214 | assert.EqualValues(t, 6, rv[2]["id"]) 215 | } 216 | 217 | { 218 | res, count, err := client.From("test").Select("*", "exact", false). 219 | Range(3, 5, ""). 220 | Order("id", &postgrest.OrderOpts{Ascending: true}). 221 | Execute() 222 | assert.NoError(t, err) 223 | 224 | var rv []map[string]interface{} 225 | tc.DecodeResult(t, res, &rv) 226 | assert.Len(t, rv, 3) 227 | assert.EqualValues(t, 4, rv[0]["id"]) 228 | assert.EqualValues(t, 5, rv[1]["id"]) 229 | assert.EqualValues(t, 6, rv[2]["id"]) 230 | 231 | assert.Equal(t, rowsCount, count) 232 | } 233 | 234 | { 235 | req := tc.NewRequest(t, http.MethodGet, "test", nil) 236 | req.Header.Set("Range", "3-5") 237 | resp := tc.ExecuteRequest(t, req) 238 | defer resp.Body.Close() 239 | 240 | res, err := io.ReadAll(resp.Body) 241 | assert.NoError(t, err) 242 | var rv []map[string]interface{} 243 | tc.DecodeResult(t, res, &rv) 244 | assert.Len(t, rv, 3) 245 | assert.EqualValues(t, 4, rv[0]["id"]) 246 | assert.EqualValues(t, 5, rv[1]["id"]) 247 | assert.EqualValues(t, 6, rv[2]["id"]) 248 | assert.Equal(t, resp.Header.Get("Content-Range"), "3-5/*") 249 | } 250 | 251 | { 252 | req := tc.NewRequest(t, http.MethodGet, "test", nil) 253 | req.Header.Set("Range", "7-") 254 | resp := tc.ExecuteRequest(t, req) 255 | defer resp.Body.Close() 256 | 257 | res, err := io.ReadAll(resp.Body) 258 | assert.NoError(t, err) 259 | var rv []map[string]interface{} 260 | tc.DecodeResult(t, res, &rv) 261 | assert.Len(t, rv, 3) 262 | assert.EqualValues(t, 8, rv[0]["id"]) 263 | assert.EqualValues(t, 9, rv[1]["id"]) 264 | assert.EqualValues(t, 10, rv[2]["id"]) 265 | assert.Equal(t, resp.Header.Get("Content-Range"), "7-/*") 266 | } 267 | }) 268 | 269 | t.Run("SelectView", func(t *testing.T) { 270 | t.Parallel() 271 | tc := createTestContext(t) 272 | defer tc.CleanUp(t) 273 | 274 | tc.ExecuteSQL(t, "CREATE TABLE test (id int)") 275 | tc.ExecuteSQL(t, `INSERT INTO test (id) VALUES (1), (1), (1)`) 276 | tc.ExecuteSQL(t, "CREATE VIEW test_view (id) AS SELECT id + 1 FROM test") 277 | 278 | client := tc.Client() 279 | res, _, err := client.From("test_view").Select("id", "", false). 
280 | Execute() 281 | assert.NoError(t, err) 282 | 283 | var rv []map[string]interface{} 284 | tc.DecodeResult(t, res, &rv) 285 | assert.Len(t, rv, 3) 286 | for _, row := range rv { 287 | assert.EqualValues(t, 2, row["id"]) 288 | } 289 | }) 290 | 291 | t.Run("SelectOperator", func(t *testing.T) { 292 | t.Parallel() 293 | tc := createTestContext(t) 294 | defer tc.CleanUp(t) 295 | 296 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text, v int nullable)") 297 | tc.ExecuteSQL(t, `INSERT INTO test (id, s, v) VALUES (1, "a", null), (2, "b", null), (3, "c", 1)`) 298 | 299 | client := tc.Client() 300 | 301 | cases := []struct { 302 | qb func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder 303 | expected []map[string]interface{} 304 | }{ 305 | { 306 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 307 | return q.Select("id", "", false). 308 | Eq("id", "1"). 309 | Eq("s", "a") 310 | }, 311 | expected: []map[string]interface{}{{"id": 1}}, 312 | }, 313 | { 314 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 315 | return q.Select("id", "", false). 316 | Lt("id", "1") 317 | }, 318 | expected: []map[string]interface{}{}, 319 | }, 320 | { 321 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 322 | return q.Select("id", "", false). 323 | Neq("id", "1"). 324 | Lt("s", "c") 325 | }, 326 | expected: []map[string]interface{}{{"id": 2}}, 327 | }, 328 | { 329 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 330 | return q.Select("id", "", false). 331 | Neq("id", "1"). 332 | Like("s", "c") 333 | }, 334 | expected: []map[string]interface{}{{"id": 3}}, 335 | }, 336 | { 337 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 338 | return q.Select("id", "", false). 339 | In("id", []string{"1", "2", "4", "10000"}). 340 | Order("id", &postgrest.OrderOpts{Ascending: true}) 341 | }, 342 | expected: []map[string]interface{}{{"id": 1}, {"id": 2}}, 343 | }, 344 | { 345 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 346 | return q.Select("id", "", false). 347 | Is("v", "null"). 348 | Order("id", &postgrest.OrderOpts{Ascending: true}) 349 | }, 350 | expected: []map[string]interface{}{{"id": 1}, {"id": 2}}, 351 | }, 352 | { 353 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 354 | return q.Select("id", "", false). 355 | Is("v", "true"). 356 | Order("id", &postgrest.OrderOpts{Ascending: true}) 357 | }, 358 | expected: []map[string]interface{}{{"id": 3}}, 359 | }, 360 | { 361 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 362 | return q.Select("id", "", false). 363 | Not("id", "eq", "1"). 364 | Order("id", &postgrest.OrderOpts{Ascending: true}) 365 | }, 366 | expected: []map[string]interface{}{{"id": 2}, {"id": 3}}, 367 | }, 368 | { 369 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 370 | return q.Select("id", "", false). 371 | Or("id.eq.1,id.eq.3", ""). 372 | Order("id", &postgrest.OrderOpts{Ascending: true}) 373 | }, 374 | expected: []map[string]interface{}{{"id": 1}, {"id": 3}}, 375 | }, 376 | { 377 | qb: func(q *postgrest.QueryBuilder) *postgrest.FilterBuilder { 378 | return q.Select("id", "", false). 379 | Or("id.eq.1,s.like.中文", ""). 
380 | Order("id", &postgrest.OrderOpts{Ascending: true}) 381 | }, 382 | expected: []map[string]interface{}{{"id": 1}}, 383 | }, 384 | } 385 | 386 | for idx := range cases { 387 | t.Run(fmt.Sprintf("case #%d", idx), func(t *testing.T) { 388 | c := cases[idx] 389 | res, _, err := c.qb(client.From("test")).Execute() 390 | assert.NoError(t, err) 391 | 392 | var rv []map[string]interface{} 393 | tc.DecodeResult(t, res, &rv) 394 | assert.Equal(t, len(c.expected), len(rv)) 395 | for idx, row := range rv { 396 | expected := c.expected[idx] 397 | assert.Equal(t, len(expected), len(row)) 398 | for k, v := range expected { 399 | assert.EqualValues(t, v, row[k]) 400 | } 401 | } 402 | }) 403 | } 404 | }) 405 | 406 | t.Run("SelectWithAdaptingColumns", func(t *testing.T) { 407 | t.Parallel() 408 | tc := createTestContext(t) 409 | defer tc.CleanUp(t) 410 | 411 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text, d text)") 412 | tc.ExecuteSQL(t, `INSERT INTO test (id, s, d) VALUES (1, "1", "a"), (2, "2", "a"), (3, "3", "a")`) 413 | 414 | client := tc.Client() 415 | res, _, err := client.From("test").Select("id_str:id::text, s::int, d_text:d", "", false). 416 | Execute() 417 | assert.NoError(t, err) 418 | 419 | var rv []map[string]interface{} 420 | tc.DecodeResult(t, res, &rv) 421 | assert.Len(t, rv, 3) 422 | for idx, row := range rv { 423 | assert.EqualValues(t, fmt.Sprint(idx+1), row["id_str"]) 424 | assert.EqualValues(t, idx+1, row["s"]) 425 | assert.EqualValues(t, "a", row["d_text"]) 426 | } 427 | }) 428 | } 429 | 430 | func TestSelect_SingleTable(t *testing.T) { 431 | t.Run("in memory db", func(t *testing.T) { 432 | testSelect_SingleTable(t, createTestContextUsingInMemoryDB) 433 | }) 434 | 435 | t.Run("HMAC token auth", func(t *testing.T) { 436 | testSelect_SingleTable(t, createTestContextWithHMACTokenAuth) 437 | }) 438 | 439 | t.Run("RSA token auth", func(t *testing.T) { 440 | testSelect_SingleTable(t, createTestContextWithRSATokenAuth) 441 | }) 442 | } 443 | -------------------------------------------------------------------------------- /integration_update_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func testUpdate_SingleTable(t *testing.T, createTestContext func(t testing.TB) *TestContext) { 12 | t.Run("NoTable", func(t *testing.T) { 13 | t.Parallel() 14 | tc := createTestContext(t) 15 | defer tc.CleanUp(t) 16 | 17 | client := tc.Client() 18 | _, _, err := client.From("test").Update(map[string]interface{}{"id": 1}, "", "1"). 19 | Execute() 20 | assert.Error(t, err) 21 | assert.Contains(t, err.Error(), "no such table: test") 22 | }) 23 | 24 | t.Run("UpdateRecords", func(t *testing.T) { 25 | t.Parallel() 26 | tc := createTestContext(t) 27 | defer tc.CleanUp(t) 28 | 29 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 30 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) VALUES (1, "a"), (1, "a"), (1, "a")`) 31 | 32 | client := tc.Client() 33 | _, _, err := client.From("test").Update(map[string]interface{}{"id": 2}, "", "3"). 34 | Execute() 35 | assert.NoError(t, err) 36 | 37 | res, _, err := client.From("test").Select("id", "", false). 
38 | Execute() 39 | assert.NoError(t, err) 40 | 41 | var rv []map[string]interface{} 42 | tc.DecodeResult(t, res, &rv) 43 | assert.Len(t, rv, 3) 44 | for _, row := range rv { 45 | assert.EqualValues(t, 2, row["id"]) 46 | } 47 | }) 48 | 49 | t.Run("UpdateWithFilter", func(t *testing.T) { 50 | t.Parallel() 51 | tc := createTestContext(t) 52 | defer tc.CleanUp(t) 53 | 54 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 55 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) VALUES (1, "a"), (1, "a"), (1, "a")`) 56 | 57 | client := tc.Client() 58 | _, _, err := client.From("test"). 59 | Update(map[string]interface{}{"id": 2}, "", "3"). 60 | Eq("id", "100"). 61 | Execute() 62 | assert.NoError(t, err) 63 | 64 | res, _, err := client.From("test").Select("id", "", false). 65 | Execute() 66 | assert.NoError(t, err) 67 | 68 | var rv []map[string]interface{} 69 | tc.DecodeResult(t, res, &rv) 70 | assert.Len(t, rv, 3) 71 | for _, row := range rv { 72 | assert.EqualValues(t, 1, row["id"]) 73 | } 74 | }) 75 | 76 | t.Run("UpdateSingleEntry", func(t *testing.T) { 77 | t.Parallel() 78 | tc := createTestContext(t) 79 | defer tc.CleanUp(t) 80 | 81 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 82 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) VALUES (1, "a"), (2, "c")`) 83 | 84 | b := bytes.NewBufferString(`{"id": 1, "s": "b"}`) 85 | req := tc.NewRequest(t, http.MethodPut, "test", b) 86 | req.Header.Set("Content-Type", "application/json") 87 | q := req.URL.Query() 88 | q.Set("id", "eq.1") 89 | req.URL.RawQuery = q.Encode() 90 | 91 | resp := tc.ExecuteRequest(t, req) 92 | defer resp.Body.Close() 93 | 94 | client := tc.Client() 95 | res, _, err := client.From("test").Select("*", "", false). 96 | Execute() 97 | assert.NoError(t, err) 98 | 99 | var rv []map[string]interface{} 100 | tc.DecodeResult(t, res, &rv) 101 | assert.Len(t, rv, 2) 102 | for idx, row := range rv { 103 | assert.EqualValues(t, idx+1, row["id"]) 104 | assert.EqualValues(t, string('b'+rune(idx)), row["s"]) 105 | } 106 | }) 107 | } 108 | 109 | func TestUpdate_SingleTable(t *testing.T) { 110 | t.Run("in memory db", func(t *testing.T) { 111 | testUpdate_SingleTable(t, createTestContextUsingInMemoryDB) 112 | }) 113 | 114 | t.Run("HMAC token auth", func(t *testing.T) { 115 | testUpdate_SingleTable(t, createTestContextWithHMACTokenAuth) 116 | }) 117 | 118 | t.Run("RSA token auth", func(t *testing.T) { 119 | testUpdate_SingleTable(t, createTestContextWithRSATokenAuth) 120 | }) 121 | } 122 | -------------------------------------------------------------------------------- /logger.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/go-logr/logr" 7 | "github.com/go-logr/zapr" 8 | "github.com/spf13/cobra" 9 | "go.uber.org/zap" 10 | "go.uber.org/zap/zapcore" 11 | ) 12 | 13 | var setupLogger logr.Logger = logr.Discard() 14 | 15 | func init() { 16 | zapConfig := zap.NewDevelopmentConfig() 17 | zapConfig.Level = zap.NewAtomicLevelAt(zapcore.Level(-12)) 18 | zapLog, err := zapConfig.Build() 19 | if err != nil { 20 | panic(err) 21 | } 22 | 23 | setupLogger = zapr.NewLogger(zapLog).WithName("setup") 24 | } 25 | 26 | func createLogger(cmd *cobra.Command) (logr.Logger, error) { 27 | logLevel, err := cmd.Flags().GetInt8(cliFlagLogLevel) 28 | if err != nil { 29 | return logr.Discard(), fmt.Errorf("read %s: %w", cliFlagLogLevel, err) 30 | } 31 | logDevel, err := cmd.Flags().GetBool(cliFlagLogDevel) 32 | if err != nil { 33 | return logr.Discard(), fmt.Errorf("read 
%s: %w", cliFlagLogDevel, err) 34 | } 35 | 36 | var zapConfig zap.Config 37 | if logDevel { 38 | zapConfig = zap.NewDevelopmentConfig() 39 | } else { 40 | zapConfig = zap.NewProductionConfig() 41 | } 42 | zapConfig.Level = zap.NewAtomicLevelAt(zapcore.Level(-logLevel)) 43 | zapLog, err := zapConfig.Build() 44 | if err != nil { 45 | return logr.Discard(), fmt.Errorf("create logger: %w", err) 46 | } 47 | 48 | return zapr.NewLogger(zapLog), nil 49 | } 50 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/spf13/cobra" 7 | "github.com/spf13/pflag" 8 | ) 9 | 10 | const ( 11 | cliFlagDBDSN = "db-dsn" 12 | cliFlagLogLevel = "log-level" 13 | cliFlagLogDevel = "log-devel" 14 | ) 15 | 16 | func bindDBDSNFlag(fs *pflag.FlagSet) { 17 | fs.String(cliFlagDBDSN, "", "Database data source name to use.") 18 | } 19 | 20 | func createMainCmd() *cobra.Command { 21 | cmd := &cobra.Command{ 22 | Use: "sqlite-rest", 23 | Short: "Serve a RESTful API from a SQLite database", 24 | SilenceUsage: true, 25 | } 26 | 27 | cmd.PersistentFlags(). 28 | Int8(cliFlagLogLevel, 5, "Log level to use. Use 8 or more for verbose log.") 29 | cmd.PersistentFlags(). 30 | Bool(cliFlagLogDevel, false, "Enable devel log format?") 31 | 32 | cmd.AddCommand( 33 | createServeCmd(), 34 | createMigrateCmd(), 35 | ) 36 | 37 | cmd.CompletionOptions.DisableDefaultCmd = true 38 | 39 | return cmd 40 | } 41 | 42 | func main() { 43 | cmd := createMainCmd() 44 | 45 | if cmd.Execute() != nil { 46 | os.Exit(1) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /metrics.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "net/http/pprof" 8 | "time" 9 | 10 | "github.com/go-chi/chi/v5" 11 | "github.com/go-chi/chi/v5/middleware" 12 | "github.com/go-logr/logr" 13 | "github.com/jmoiron/sqlx" 14 | "github.com/prometheus/client_golang/prometheus" 15 | "github.com/prometheus/client_golang/prometheus/promauto" 16 | "github.com/prometheus/client_golang/prometheus/promhttp" 17 | "github.com/spf13/pflag" 18 | ) 19 | 20 | func init() { 21 | // NOTE: this is to remove the registered pprof handlers from net/http/pprof init call. 22 | // The pprof handlers will be registered only if the pprof server is enabled. 23 | http.DefaultServeMux = http.NewServeMux() 24 | } 25 | 26 | const metricsServerDisabledAddr = "" 27 | const pprofServerDisabledAddr = "" 28 | 29 | type MetricsServerOptions struct { 30 | Logger logr.Logger 31 | Addr string 32 | Queryer sqlx.QueryerContext 33 | } 34 | 35 | func (opts *MetricsServerOptions) bindCLIFlags(fs *pflag.FlagSet) { 36 | fs.StringVar( 37 | &opts.Addr, "metrics-addr", ":8081", 38 | "metrics server listen address. 
Empty value means disabled.", 39 | ) 40 | } 41 | 42 | func (opts *MetricsServerOptions) defaults() error { 43 | if opts.Logger.GetSink() == nil { 44 | opts.Logger = logr.Discard() 45 | } 46 | 47 | if opts.Addr != metricsServerDisabledAddr { 48 | if opts.Queryer == nil { 49 | return fmt.Errorf(".Queryer is required") 50 | } 51 | } 52 | 53 | return nil 54 | } 55 | 56 | type metricsServer struct { 57 | logger logr.Logger 58 | server *http.Server 59 | queryer sqlx.QueryerContext 60 | } 61 | 62 | func NewMetricsServer(opts MetricsServerOptions) (*metricsServer, error) { 63 | if err := opts.defaults(); err != nil { 64 | return nil, err 65 | } 66 | 67 | srv := &metricsServer{ 68 | logger: opts.Logger, 69 | queryer: opts.Queryer, 70 | } 71 | 72 | if opts.Addr == metricsServerDisabledAddr { 73 | return srv, nil 74 | } 75 | 76 | serverMux := http.NewServeMux() 77 | serverMux.Handle("/metrics", promhttp.Handler()) 78 | srv.server = &http.Server{ 79 | Addr: opts.Addr, 80 | Handler: serverMux, 81 | } 82 | 83 | return srv, nil 84 | } 85 | 86 | func (server *metricsServer) monitorDatabaseSize( 87 | done <-chan struct{}, 88 | observeFn func(sizeInBytes float64), 89 | ) { 90 | const dbSizeQuery = `SELECT 91 | page_count * page_size 92 | FROM pragma_page_count(), pragma_page_size();` 93 | 94 | observe := func() { 95 | var size int64 96 | err := server.queryer.QueryRowxContext(context.Background(), dbSizeQuery).Scan(&size) 97 | if err != nil { 98 | server.logger.Error(err, "failed to get database size") 99 | return 100 | } 101 | 102 | observeFn(float64(size)) 103 | } 104 | observe() 105 | 106 | ticker := time.NewTicker(30 * time.Second) 107 | defer ticker.Stop() 108 | for { 109 | select { 110 | case <-done: 111 | return 112 | case <-ticker.C: 113 | observe() 114 | } 115 | } 116 | } 117 | 118 | func (server *metricsServer) Start(done <-chan struct{}) { 119 | if server.server == nil { 120 | server.logger.V(8).Info("metrics server is disabled") 121 | return 122 | } 123 | 124 | go server.monitorDatabaseSize(done, func(sizeInBytes float64) { 125 | metricsDatabaseSize.Set(sizeInBytes) 126 | server.logger.V(8).Info("database size", "sizeInBytes", sizeInBytes) 127 | }) 128 | go server.server.ListenAndServe() 129 | 130 | server.logger.Info("metrics server started", "addr", server.server.Addr) 131 | <-done 132 | 133 | server.logger.Info("shutting metrics server") 134 | shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 135 | defer cancel() 136 | server.server.Shutdown(shutdownCtx) 137 | } 138 | 139 | const ( 140 | metricsNamespace = "sqlite_rest" 141 | metricsLabelTarget = "target" // name of the table/view 142 | metricsLabelTargetOperation = "operation" // name of the operation 143 | metricsLabelHTTPCode = "http_code" // HTTP response code 144 | ) 145 | 146 | var ( 147 | metricsAuthFailedRequestsTotal = promauto.NewCounter( 148 | prometheus.CounterOpts{ 149 | Namespace: metricsNamespace, 150 | Name: "auth_failed_requests_total", 151 | Help: "Total number of failed authentication requests", 152 | }, 153 | ) 154 | 155 | metricsAccessCheckFailedRequestsTotal = promauto.NewCounter( 156 | prometheus.CounterOpts{ 157 | Namespace: metricsNamespace, 158 | Name: "access_check_failed_requests_total", 159 | Help: "Total number of failed access check requests", 160 | }, 161 | ) 162 | 163 | metricsRequestTotal = promauto.NewCounterVec( 164 | prometheus.CounterOpts{ 165 | Namespace: metricsNamespace, 166 | Name: "http_requests_total", 167 | Help: "Total number of HTTP requests", 168 | }, 169 | 
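// Both the request counter and the latency histogram below share the same label set: the target table/view, the operation, and the HTTP status code.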
[]string{metricsLabelTarget, metricsLabelTargetOperation, metricsLabelHTTPCode}, 170 | ) 171 | 172 | metricsRequestLatency = promauto.NewHistogramVec( 173 | prometheus.HistogramOpts{ 174 | Namespace: metricsNamespace, 175 | Name: "http_request_duration_milliseconds", 176 | Help: "HTTP request latency", 177 | Buckets: []float64{1, 10, 100, 500, 1000}, 178 | }, 179 | []string{metricsLabelTarget, metricsLabelTargetOperation, metricsLabelHTTPCode}, 180 | ) 181 | 182 | metricsDatabaseSize = promauto.NewGauge( 183 | prometheus.GaugeOpts{ 184 | Namespace: metricsNamespace, 185 | Name: "database_size_bytes", 186 | Help: "Size of the database file", 187 | }, 188 | ) 189 | ) 190 | 191 | func recordRequestMetrics(op string) func(http.Handler) http.Handler { 192 | return func(next http.Handler) http.Handler { 193 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 194 | start := time.Now() 195 | ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) 196 | 197 | defer func() { 198 | httpCode := fmt.Sprint(ww.Status()) 199 | target := chi.URLParam(r, routeVarTableOrView) 200 | metricsRequestTotal. 201 | WithLabelValues(target, op, httpCode). 202 | Inc() 203 | metricsRequestLatency. 204 | WithLabelValues(target, op, httpCode). 205 | Observe(float64(time.Since(start).Milliseconds())) 206 | }() 207 | 208 | next.ServeHTTP(ww, r) 209 | }) 210 | } 211 | } 212 | 213 | type PprofServerOptions struct { 214 | Logger logr.Logger 215 | Addr string 216 | } 217 | 218 | func (opts *PprofServerOptions) bindCLIFlags(fs *pflag.FlagSet) { 219 | fs.StringVar( 220 | &opts.Addr, "pprof-addr", pprofServerDisabledAddr, 221 | "pprof server listen address. Empty value means disabled.", 222 | ) 223 | } 224 | 225 | func (opts *PprofServerOptions) defaults() error { 226 | if opts.Logger.GetSink() == nil { 227 | opts.Logger = logr.Discard() 228 | } 229 | 230 | return nil 231 | } 232 | 233 | type pprofServer struct { 234 | logger logr.Logger 235 | server *http.Server 236 | } 237 | 238 | func NewPprofServer(opts PprofServerOptions) (*pprofServer, error) { 239 | if err := opts.defaults(); err != nil { 240 | return nil, err 241 | } 242 | 243 | srv := &pprofServer{ 244 | logger: opts.Logger, 245 | } 246 | 247 | if opts.Addr == pprofServerDisabledAddr { 248 | return srv, nil 249 | } 250 | 251 | serverMux := http.NewServeMux() 252 | serverMux.HandleFunc("/debug/pprof/", pprof.Index) 253 | serverMux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) 254 | serverMux.HandleFunc("/debug/pprof/profile", pprof.Profile) 255 | serverMux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) 256 | serverMux.HandleFunc("/debug/pprof/trace", pprof.Trace) 257 | 258 | srv.server = &http.Server{ 259 | Addr: opts.Addr, 260 | Handler: serverMux, 261 | } 262 | 263 | return srv, nil 264 | } 265 | 266 | func (server *pprofServer) Start(done <-chan struct{}) { 267 | if server.server == nil { 268 | return 269 | } 270 | 271 | server.logger.Info("pprof server is enabled, make sure it's not exposed to the public internet") 272 | 273 | go server.server.ListenAndServe() 274 | 275 | server.logger.Info("pprof server started", "addr", server.server.Addr) 276 | <-done 277 | 278 | server.logger.Info("shutting pprof server") 279 | shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 280 | defer cancel() 281 | server.server.Shutdown(shutdownCtx) 282 | } 283 | -------------------------------------------------------------------------------- /metrics_test.go: 
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestMetricsServer_monitorDatabaseSize(t *testing.T) { 11 | t.Parallel() 12 | 13 | tc := createTestContextWithHMACTokenAuth(t) 14 | defer tc.CleanUp(t) 15 | 16 | tc.ExecuteSQL(t, "CREATE TABLE test (id int, s text)") 17 | tc.ExecuteSQL(t, `INSERT INTO test (id, s) VALUES (1, "a"), (1, "a"), (1, "a")`) 18 | 19 | metricsServer, err := NewMetricsServer(MetricsServerOptions{ 20 | Logger: createTestLogger(t).WithName("test"), 21 | Addr: ":8081", 22 | Queryer: tc.DB(), 23 | }) 24 | assert.NoError(t, err) 25 | 26 | done := make(chan struct{}) 27 | observeFinish := make(chan struct{}) 28 | 29 | go metricsServer.monitorDatabaseSize(done, func(sizeInBytes float64) { 30 | close(observeFinish) 31 | 32 | assert.True(t, sizeInBytes > 0) 33 | }) 34 | 35 | time.Sleep(100 * time.Millisecond) 36 | close(done) 37 | <-observeFinish 38 | } 39 | -------------------------------------------------------------------------------- /migrate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "io/fs" 9 | "os" 10 | "path/filepath" 11 | 12 | "github.com/go-logr/logr" 13 | "github.com/golang-migrate/migrate/v4" 14 | "github.com/golang-migrate/migrate/v4/database/sqlite3" 15 | _ "github.com/golang-migrate/migrate/v4/source/file" 16 | "github.com/spf13/cobra" 17 | ) 18 | 19 | const ( 20 | tableNameMigrations = "__sqlite_rest_migrations" 21 | 22 | migrationDirectionUp = "up" 23 | migrationDirectionDown = "down" 24 | 25 | migrationStepAll = -1 26 | ) 27 | 28 | func isApplyAllStep(step int) bool { 29 | return step <= 0 30 | } 31 | 32 | func createMigrateCmd() *cobra.Command { 33 | var ( 34 | flagDirection string 35 | flagStep int 36 | ) 37 | 38 | cmd := &cobra.Command{ 39 | Use: "migrate migrations-dir", 40 | Short: "Apply database migrations", 41 | SilenceUsage: true, 42 | Args: cobra.ExactArgs(1), 43 | RunE: func(cmd *cobra.Command, args []string) error { 44 | logger, err := createLogger(cmd) 45 | if err != nil { 46 | setupLogger.Error(err, "failed to create logger") 47 | return err 48 | } 49 | 50 | db, err := openDB(cmd) 51 | if err != nil { 52 | setupLogger.Error(err, "create db") 53 | return err 54 | } 55 | defer db.Close() 56 | 57 | opts := &MigrateOptions{ 58 | Logger: logger, 59 | DB: db.DB, 60 | SourceDIR: args[0], 61 | } 62 | migrator, err := NewMigrator(opts) 63 | if err != nil { 64 | setupLogger.Error(err, "failed to create migrator") 65 | return err 66 | } 67 | 68 | ctx, cancel := context.WithCancel(context.Background()) 69 | defer cancel() 70 | 71 | var migrateErr error 72 | switch flagDirection { 73 | case migrationDirectionUp: 74 | migrateErr = migrator.Up(ctx, flagStep) 75 | case migrationDirectionDown: 76 | migrateErr = migrator.Down(ctx, flagStep) 77 | default: 78 | // defaults to up 79 | migrateErr = migrator.Up(ctx, flagStep) 80 | } 81 | if migrateErr != nil { 82 | return migrateErr 83 | } 84 | 85 | return nil 86 | }, 87 | } 88 | 89 | bindDBDSNFlag(cmd.Flags()) 90 | 91 | return cmd 92 | } 93 | 94 | type MigrateOptions struct { 95 | Logger logr.Logger 96 | DB *sql.DB 97 | SourceDIR string 98 | } 99 | 100 | func (opts *MigrateOptions) defaults() error { 101 | if opts.Logger.GetSink() == nil { 102 | opts.Logger = logr.Discard() 103 | } 104 | 105 | if opts.DB == nil { 106 | return 
fmt.Errorf(".DB is required") 107 | } 108 | 109 | if opts.SourceDIR == "" { 110 | return fmt.Errorf(".SourceDIR is required") 111 | } 112 | if s, err := filepath.Abs(opts.SourceDIR); err == nil { 113 | opts.SourceDIR = s 114 | } else { 115 | return fmt.Errorf("failed to resolve SourceDIR %q: %w", opts.SourceDIR, err) 116 | } 117 | stat, err := os.Stat(opts.SourceDIR) 118 | if err != nil { 119 | return fmt.Errorf("%s: %w", opts.SourceDIR, err) 120 | } 121 | if !stat.IsDir() { 122 | return fmt.Errorf("migrations source dir %q is not a dir", opts.SourceDIR) 123 | } 124 | 125 | return nil 126 | } 127 | 128 | type dbMigrator struct { 129 | logger logr.Logger 130 | migrator *migrate.Migrate 131 | } 132 | 133 | func NewMigrator(opts *MigrateOptions) (*dbMigrator, error) { 134 | if err := opts.defaults(); err != nil { 135 | return nil, err 136 | } 137 | 138 | driver, err := sqlite3.WithInstance(opts.DB, &sqlite3.Config{ 139 | MigrationsTable: tableNameMigrations, 140 | }) 141 | if err != nil { 142 | return nil, err 143 | } 144 | migrator, err := migrate.NewWithDatabaseInstance( 145 | "file://"+opts.SourceDIR, 146 | "sqlite3", driver, 147 | ) 148 | if err != nil { 149 | return nil, err 150 | } 151 | 152 | rv := &dbMigrator{ 153 | logger: opts.Logger.WithName("db-migrator"), 154 | migrator: migrator, 155 | } 156 | 157 | return rv, nil 158 | } 159 | 160 | func handleMigrateError(logger logr.Logger, op string, migrateErr error) error { 161 | if migrateErr == nil { 162 | logger.Info("applied operation") 163 | return nil 164 | } 165 | 166 | if errors.Is(migrateErr, migrate.ErrNoChange) { 167 | // no update 168 | logger.V(8).Info("no pending migrations") 169 | return nil 170 | } 171 | 172 | var pathErr *fs.PathError 173 | if errors.As(migrateErr, &pathErr) { 174 | // no migrations set 175 | if pathErr.Op == "first" && errors.Is(pathErr.Err, fs.ErrNotExist) { 176 | logger.Info("no migrations to apply") 177 | return nil 178 | } 179 | } 180 | 181 | logger.Error(migrateErr, "failed to apply operation") 182 | return fmt.Errorf("%s: %w", op, migrateErr) 183 | } 184 | 185 | func (m *dbMigrator) Up(ctx context.Context, step int) error { 186 | logger := m.logger.WithName("up") 187 | logger.Info("applying operation") 188 | 189 | var migrateErr error 190 | 191 | if isApplyAllStep(step) { 192 | migrateErr = m.migrator.Up() 193 | } else { 194 | migrateErr = m.migrator.Steps(step) 195 | } 196 | 197 | return handleMigrateError(logger, "up", migrateErr) 198 | } 199 | 200 | func (m *dbMigrator) Down(ctx context.Context, step int) error { 201 | logger := m.logger.WithName("down") 202 | logger.Info("applying operation") 203 | 204 | var migrateErr error 205 | 206 | if isApplyAllStep(step) { 207 | migrateErr = m.migrator.Down() 208 | } else { 209 | migrateErr = m.migrator.Steps(-step) 210 | } 211 | 212 | return handleMigrateError(logger, "down", migrateErr) 213 | } 214 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "mime" 10 | "net/http" 11 | "sort" 12 | "strconv" 13 | "strings" 14 | ) 15 | 16 | const ( 17 | queryParameterNameSelect = "select" 18 | queryParameterNameOrder = "order" 19 | queryParameterNameLimit = "limit" 20 | queryParameterNameOffset = "offset" 21 | queryParameterNameOnConflict = "on_conflict" 22 | 23 | headerNamePrefer = "Prefer" 24 | headerNameRangeUnit = "range-unit" 25 | 
headerNameRange = "range" 26 | 27 | logicalOperatorNot = "not" 28 | logicalOperatorAnd = "and" 29 | logicalOperatorOr = "or" 30 | 31 | doubleColonCastingOperator = "::" // NOTE: this is a PostgreSQL specific operator 32 | singleColonRenameOperator = ":" 33 | ) 34 | 35 | type CompiledQuery struct { 36 | Query string 37 | Values []interface{} 38 | } 39 | 40 | func (q CompiledQuery) String() string { 41 | return fmt.Sprintf("query=%q values=%v", q.Query, q.Values) 42 | } 43 | 44 | type QueryCompiler interface { 45 | CompileAsSelect(table string) (CompiledQuery, error) 46 | CompileAsExactCount(table string) (CompiledQuery, error) 47 | CompileAsUpdate(table string) (CompiledQuery, error) 48 | CompileAsUpdateSingleEntry(table string) (CompiledQuery, error) 49 | CompileAsInsert(table string) (CompiledQuery, error) 50 | CompileAsDelete(table string) (CompiledQuery, error) 51 | CompileContentRangeHeader(totalCount string) string 52 | } 53 | 54 | type queryCompiler struct { 55 | req *http.Request 56 | } 57 | 58 | func NewQueryCompilerFromRequest(req *http.Request) QueryCompiler { 59 | return &queryCompiler{req: req} 60 | } 61 | 62 | func (c *queryCompiler) getQueryParameters(name string) []string { 63 | qp := c.req.URL.Query() 64 | if !qp.Has(name) { 65 | return nil 66 | } 67 | return qp[name] 68 | } 69 | 70 | func (c *queryCompiler) getQueryParameter(name string) string { 71 | qp := c.req.URL.Query() 72 | if !qp.Has(name) { 73 | return "" 74 | } 75 | return qp.Get(name) 76 | } 77 | 78 | func (c *queryCompiler) CompileAsSelect(table string) (CompiledQuery, error) { 79 | rv := CompiledQuery{} 80 | 81 | rv.Query = fmt.Sprintf( 82 | "select %s from %s", 83 | strings.Join(c.getSelectResultColumns(), ", "), 84 | table, 85 | ) 86 | 87 | parsedQueryClauses, err := c.getQueryClauses() 88 | if err != nil { 89 | return rv, err 90 | } 91 | var queryClauses []string 92 | for _, qc := range parsedQueryClauses { 93 | queryClauses = append(queryClauses, qc.Expr) 94 | rv.Values = append(rv.Values, qc.Values...) 95 | } 96 | if len(queryClauses) > 0 { 97 | rv.Query = fmt.Sprintf("%s where %s", rv.Query, strings.Join(queryClauses, " and ")) 98 | } 99 | 100 | orderClauses, err := c.getOrderClauses() 101 | if err != nil { 102 | return rv, err 103 | } 104 | if len(orderClauses) > 0 { 105 | rv.Query = fmt.Sprintf("%s order by %s", rv.Query, strings.Join(orderClauses, ", ")) 106 | } 107 | 108 | limit, offset, err := c.getLimitOffset() 109 | switch { 110 | case err == nil: 111 | rv.Query = fmt.Sprintf("%s limit %d", rv.Query, limit) 112 | if offset != 0 { 113 | rv.Query = fmt.Sprintf("%s offset %d", rv.Query, offset) 114 | } 115 | case errors.Is(err, errNoLimitOffset): 116 | // no limit/offset 117 | default: 118 | return rv, err 119 | } 120 | 121 | return rv, nil 122 | } 123 | 124 | func (c *queryCompiler) CompileAsExactCount(table string) (CompiledQuery, error) { 125 | rv := CompiledQuery{} 126 | 127 | rv.Query = fmt.Sprintf( 128 | "select count(1) from %s", 129 | table, 130 | ) 131 | 132 | parsedQueryClauses, err := c.getQueryClauses() 133 | if err != nil { 134 | return rv, err 135 | } 136 | var queryClauses []string 137 | for _, qc := range parsedQueryClauses { 138 | queryClauses = append(queryClauses, qc.Expr) 139 | rv.Values = append(rv.Values, qc.Values...) 
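// Filter expressions carry their values separately so they are bound as query parameters rather than spliced into the SQL text; the count query reuses the same clauses as the select, so exact counts reflect the filtered rows.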
140 | } 141 | if len(queryClauses) > 0 { 142 | rv.Query = fmt.Sprintf("%s where %s", rv.Query, strings.Join(queryClauses, " and ")) 143 | } 144 | 145 | return rv, nil 146 | } 147 | 148 | func (c *queryCompiler) CompileAsUpdate(table string) (CompiledQuery, error) { 149 | rv := CompiledQuery{} 150 | 151 | payload, err := c.getInputPayload() 152 | if err != nil { 153 | return rv, err 154 | } 155 | if len(payload.Columns) < 1 { 156 | return rv, ErrBadRequest.WithHint("no columns to update") 157 | } 158 | if len(payload.Payload) < 1 { 159 | return rv, ErrBadRequest.WithHint("no data to update") 160 | } 161 | if len(payload.Payload) > 1 { 162 | return rv, ErrBadRequest.WithHint("more than one row to update") 163 | } 164 | 165 | columns := payload.GetSortedColumns() 166 | updateValues := payload.Payload[0] 167 | var columnPlaceholders []string 168 | for _, column := range columns { 169 | columnPlaceholders = append(columnPlaceholders, fmt.Sprintf("%s = ?", column)) 170 | rv.Values = append(rv.Values, updateValues[column]) 171 | } 172 | 173 | rv.Query = fmt.Sprintf( 174 | "update %s set %s", 175 | table, 176 | strings.Join(columnPlaceholders, ", "), 177 | ) 178 | 179 | parsedQueryClauses, err := c.getQueryClauses() 180 | if err != nil { 181 | return rv, err 182 | } 183 | var qcs []string 184 | for _, qc := range parsedQueryClauses { 185 | qcs = append(qcs, qc.Expr) 186 | rv.Values = append(rv.Values, qc.Values...) 187 | } 188 | if len(qcs) > 0 { 189 | rv.Query = fmt.Sprintf("%s where %s", rv.Query, strings.Join(qcs, " and ")) 190 | } 191 | 192 | return rv, nil 193 | } 194 | 195 | func (c *queryCompiler) CompileAsUpdateSingleEntry(table string) (CompiledQuery, error) { 196 | rv := CompiledQuery{} 197 | 198 | payload, err := c.getInputPayload() 199 | if err != nil { 200 | return rv, err 201 | } 202 | if len(payload.Columns) < 1 { 203 | return rv, ErrBadRequest.WithHint("no columns to update") 204 | } 205 | if len(payload.Payload) < 1 { 206 | return rv, ErrBadRequest.WithHint("no data to update") 207 | } 208 | if len(payload.Payload) > 1 { 209 | return rv, ErrBadRequest.WithHint("more than one row to update") 210 | } 211 | 212 | columns := payload.GetSortedColumns() 213 | updateValues := payload.Payload[0] 214 | var columnPlaceholders []string 215 | for _, column := range columns { 216 | columnPlaceholders = append(columnPlaceholders, fmt.Sprintf("%s = ?", column)) 217 | rv.Values = append(rv.Values, updateValues[column]) 218 | } 219 | 220 | rv.Query = fmt.Sprintf( 221 | "update %s set %s", 222 | table, 223 | strings.Join(columnPlaceholders, ", "), 224 | ) 225 | 226 | parsedQueryClauses, err := c.getQueryClauses() 227 | if err != nil { 228 | return rv, err 229 | } 230 | if len(parsedQueryClauses) < 1 { 231 | return rv, ErrBadRequest.WithHint("expect to specify a primary key filter") 232 | } 233 | var qcs []string 234 | for _, qc := range parsedQueryClauses { 235 | qcs = append(qcs, qc.Expr) 236 | rv.Values = append(rv.Values, qc.Values...) 
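// Note: the set-clause column names come from the request JSON keys and are interpolated directly; only the values are bound as parameters (see also the FIXME on on_conflict in CompileAsInsert).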
237 | } 238 | rv.Query = fmt.Sprintf("%s where %s", rv.Query, strings.Join(qcs, " and ")) 239 | // make sure only one row will be updated 240 | // Needs SQLITE_ENABLE_UPDATE_DELETE_LIMIT , but it's not available in mattn/sqlite3 241 | // rv.Query = fmt.Sprintf("%s limit 1", rv.Query) 242 | 243 | return rv, nil 244 | } 245 | 246 | func (c *queryCompiler) CompileAsInsert(table string) (CompiledQuery, error) { 247 | rv := CompiledQuery{} 248 | 249 | preference, err := ParsePreferenceFromRequest(c.req) 250 | if err != nil { 251 | return rv, err 252 | } 253 | 254 | payload, err := c.getInputPayload() 255 | if err != nil { 256 | return rv, err 257 | } 258 | if len(payload.Columns) < 1 { 259 | return rv, ErrBadRequest.WithHint("no columns to insert") 260 | } 261 | if len(payload.Payload) < 1 { 262 | return rv, ErrBadRequest.WithHint("no data to insert") 263 | } 264 | 265 | columns := payload.GetSortedColumns() 266 | 267 | values := payload.GetValues(columns) 268 | var valuePlaceholders []string 269 | for range values { 270 | valuePlaceholders = append( 271 | valuePlaceholders, 272 | fmt.Sprintf("(%s?)", strings.Repeat("?, ", len(columns)-1)), 273 | ) 274 | } 275 | 276 | rv.Query = fmt.Sprintf( 277 | `insert into %s (%s) values %s`, 278 | table, 279 | strings.Join(columns, ", "), 280 | strings.Join(valuePlaceholders, ", "), 281 | ) 282 | 283 | for _, v := range values { 284 | rv.Values = append(rv.Values, v...) 285 | } 286 | 287 | if preference.Resolution != resolutionNone { 288 | // FIXME: this is a potential sql injection vulnerability 289 | var onConflictColumns []string 290 | v := c.getQueryParameter(queryParameterNameOnConflict) 291 | if v != "" { 292 | onConflictColumns = strings.Split(v, ",") 293 | } 294 | var onConflictColumnsClause string 295 | if len(onConflictColumns) > 0 { 296 | onConflictColumnsClause = fmt.Sprintf(" (%s)", strings.Join(onConflictColumns, ", ")) 297 | } 298 | 299 | switch preference.Resolution { 300 | case resolutionIgnoreDuplicates: 301 | rv.Query = fmt.Sprintf("%s on conflict%s do nothing", rv.Query, onConflictColumnsClause) 302 | case resolutionMergeDuplicates: 303 | var excludedColumns []string 304 | for _, column := range columns { 305 | excludedColumns = append(excludedColumns, fmt.Sprintf("%s = excluded.%s", column, column)) 306 | } 307 | rv.Query = fmt.Sprintf( 308 | "%s on conflict%s do update set %s", 309 | rv.Query, 310 | onConflictColumnsClause, 311 | strings.Join(excludedColumns, ", "), 312 | ) 313 | } 314 | } 315 | 316 | return rv, nil 317 | } 318 | 319 | func (c *queryCompiler) CompileAsDelete(table string) (CompiledQuery, error) { 320 | rv := CompiledQuery{} 321 | 322 | rv.Query = fmt.Sprintf(`delete from %s`, table) 323 | 324 | parsedQueryClauses, err := c.getQueryClauses() 325 | if err != nil { 326 | return rv, err 327 | } 328 | var qcs []string 329 | for _, qc := range parsedQueryClauses { 330 | qcs = append(qcs, qc.Expr) 331 | rv.Values = append(rv.Values, qc.Values...) 
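// If the request carries no filters, the where clause below is omitted and the delete applies to every row of the table.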
332 | } 333 | if len(qcs) > 0 { 334 | rv.Query = fmt.Sprintf("%s where %s", rv.Query, strings.Join(qcs, " and ")) 335 | } 336 | 337 | return rv, nil 338 | } 339 | 340 | func getSelectResultColumn(columnName string) string { 341 | // newName:name::text 342 | // => cast(name as text) as newName 343 | 344 | var ( 345 | columnType string 346 | targetColumnName string 347 | ) 348 | 349 | if strings.Contains(columnName, doubleColonCastingOperator) { 350 | ps := strings.SplitN(columnName, doubleColonCastingOperator, 2) 351 | if len(ps) == 2 { 352 | // is a valid casting call 353 | columnName = ps[0] 354 | columnType = ps[1] 355 | } 356 | // NOTE: if it's not a valid casting, since the columnType is still empty, 357 | // no casting will be applied 358 | } 359 | 360 | if strings.Contains(columnName, singleColonRenameOperator) { 361 | ps := strings.SplitN(columnName, singleColonRenameOperator, 2) 362 | if len(ps) == 2 { 363 | // is a valid renaming call 364 | targetColumnName = ps[0] 365 | columnName = ps[1] 366 | } 367 | // NOTE: if it's not a valid renaming, since the targetColumnName is still empty, 368 | // no renaming will be applied 369 | } 370 | 371 | if columnType == "" { 372 | if targetColumnName == "" { 373 | return columnName 374 | } 375 | return fmt.Sprintf("%s as %s", columnName, targetColumnName) 376 | } else { 377 | if targetColumnName == "" { 378 | targetColumnName = columnName 379 | } 380 | return fmt.Sprintf("cast(%s as %s) as %s", columnName, columnType, targetColumnName) 381 | } 382 | } 383 | 384 | func (c *queryCompiler) getSelectResultColumns() []string { 385 | v := c.getQueryParameter(queryParameterNameSelect) 386 | if v == "" { 387 | return []string{"*"} 388 | } 389 | 390 | vs := strings.Split(v, ",") 391 | // TOOD: support renaming 392 | for idx := range vs { 393 | vs[idx] = getSelectResultColumn(vs[idx]) 394 | } 395 | 396 | return vs 397 | } 398 | 399 | func (c *queryCompiler) getQueryClauses() ([]CompiledQueryParameter, error) { 400 | var rv []CompiledQueryParameter 401 | for k := range c.req.URL.Query() { 402 | if !c.isColumnName(k) { 403 | continue 404 | } 405 | 406 | vs, err := c.getQueryClausesByColumn(k) 407 | if err != nil { 408 | return nil, err 409 | } 410 | if len(vs) < 1 { 411 | continue 412 | } 413 | 414 | rv = append(rv, vs...) 415 | } 416 | 417 | return rv, nil 418 | } 419 | 420 | func (c *queryCompiler) isColumnName(s string) bool { 421 | switch strings.ToLower(s) { 422 | case queryParameterNameSelect, 423 | queryParameterNameOrder, 424 | queryParameterNameLimit, 425 | queryParameterNameOffset, 426 | queryParameterNameOnConflict: 427 | return false 428 | default: 429 | return true 430 | } 431 | } 432 | 433 | func (c *queryCompiler) getQueryClausesByColumn( 434 | column string, 435 | ) ([]CompiledQueryParameter, error) { 436 | vs := c.getQueryParameters(column) 437 | if len(vs) < 1 { 438 | return nil, nil 439 | } 440 | 441 | var rv []CompiledQueryParameter 442 | for _, v := range vs { 443 | ps, err := c.getQueryClausesByInput(column, v) 444 | if err != nil { 445 | return nil, err 446 | } 447 | if len(ps) < 1 { 448 | continue 449 | } 450 | rv = append(rv, ps...) 
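// The same column may be filtered more than once (e.g. id=gt.1&id=lt.10); each occurrence yields its own clause, and the callers join them with "and".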
451 | } 452 | 453 | return rv, nil 454 | } 455 | 456 | func (c *queryCompiler) getQueryClausesByInput( 457 | column string, 458 | s string, 459 | ) ([]CompiledQueryParameter, error) { 460 | if s == "" { 461 | return nil, nil 462 | } 463 | 464 | switch column { 465 | case logicalOperatorAnd, logicalOperatorOr: 466 | // or=a.eq.1,b.eq.2 => or(a.eq.1, b.eq.2) 467 | return parseQueryClauses(fmt.Sprintf("%s(%s)", column, s)) 468 | default: 469 | // id=eq.1 470 | return parseQueryClauses(fmt.Sprintf("%s.%s", column, s)) 471 | } 472 | } 473 | 474 | var orderByNulls = map[string]string{ 475 | "nullslast": "nulls last", 476 | "nullsfirst": "nulls first", 477 | } 478 | 479 | func (c *queryCompiler) getOrderClauses() ([]string, error) { 480 | v := c.getQueryParameter(queryParameterNameOrder) 481 | if v == "" { 482 | return nil, nil 483 | } 484 | 485 | translateOrderBy := func(s string) string { 486 | if v, exists := orderByNulls[s]; exists { 487 | return v 488 | } 489 | return s 490 | } 491 | 492 | var vs []string 493 | for _, v := range strings.Split(v, ",") { 494 | ps := strings.Split(v, ".") 495 | switch { 496 | case len(ps) == 1: 497 | vs = append(vs, ps[0]) 498 | case len(ps) == 2: 499 | // a.asc -> a asc 500 | // a.nullslast -> a nulls last 501 | vs = append(vs, fmt.Sprintf("%s %s", ps[0], translateOrderBy(ps[1]))) 502 | case len(ps) == 3: 503 | // a.asc.nullslast 504 | vs = append(vs, fmt.Sprintf("%s %s %s", ps[0], ps[1], translateOrderBy(ps[2]))) 505 | default: 506 | // invalid 507 | return nil, fmt.Errorf("invalid order by clause: %s", v) 508 | } 509 | } 510 | 511 | return vs, nil 512 | } 513 | 514 | var errNoLimitOffset = errors.New("no limit offset") 515 | 516 | func (c *queryCompiler) CompileContentRangeHeader(totalCount string) string { 517 | limit, offset, err := c.getLimitOffset() 518 | if err != nil { 519 | // unable to infer limit/offset 520 | return "" 521 | } 522 | 523 | if limit < 0 { 524 | // unbound range 525 | return fmt.Sprintf("%d-/%s", offset, totalCount) 526 | } 527 | 528 | return fmt.Sprintf("%d-%d/%s", offset, offset+limit-1, totalCount) 529 | } 530 | 531 | func (c *queryCompiler) getLimitOffset() (limit int64, offset int64, err error) { 532 | limit, offset, err = c.getLimitOffsetFromHeader() 533 | if err == nil { 534 | return limit, offset, nil 535 | } 536 | if !errors.Is(err, errNoLimitOffset) { 537 | return 0, 0, err 538 | } 539 | return c.getLimitOffsetFromQueryParameter() 540 | } 541 | 542 | func (c *queryCompiler) getLimitOffsetFromHeader() (int64, int64, error) { 543 | rangeValue := c.req.Header.Get(headerNameRange) 544 | if rangeValue == "" { 545 | return 0, 0, errNoLimitOffset 546 | } 547 | 548 | ps := strings.SplitN(rangeValue, "-", 2) 549 | if len(ps) < 2 { // no "-" separator: not a usable range header 550 | return 0, 0, errNoLimitOffset 551 | } 552 | 553 | offset, err := strconv.ParseInt(ps[0], 10, 64) 554 | if err != nil { 555 | return 0, 0, err 556 | } 557 | if ps[1] == "" { 558 | // no limit, per: https://www.sqlite.org/lang_select.html#limitoffset 559 | // If the LIMIT expression evaluates to a negative value, 560 | // then there is no upper bound on the number of rows returned 561 | return -1, offset, nil 562 | } 563 | to, err := strconv.ParseInt(ps[1], 10, 64) 564 | if err != nil { 565 | return 0, 0, err 566 | } 567 | 568 | return to - offset + 1, offset, nil 569 | } 570 | 571 | func (c *queryCompiler) getLimitOffsetFromQueryParameter() (int64, int64, error) { 572 | getInt64 := func(qp string) (int64, error) { 573 | v := c.getQueryParameter(qp) 574 | if v == "" { 575 | return 0, errNoLimitOffset 576 
| } 577 | return strconv.ParseInt(v, 10, 64) 578 | } 579 | 580 | limit, err := getInt64(queryParameterNameLimit) 581 | if err != nil { 582 | return 0, 0, err 583 | } 584 | offset, err := getInt64(queryParameterNameOffset) 585 | switch { 586 | case err == nil: 587 | return limit, offset, nil 588 | case errors.Is(err, errNoLimitOffset): 589 | // offset is optional 590 | return limit, 0, nil 591 | default: 592 | return 0, 0, err 593 | } 594 | } 595 | 596 | func (c *queryCompiler) getInputPayload() (InputPayloadWithColumns, error) { 597 | contentType := c.req.Header.Get("content-type") 598 | if contentType == "" { 599 | contentType = "application/octet-stream" 600 | } 601 | 602 | for _, v := range strings.Split(contentType, ",") { 603 | mt, _, err := mime.ParseMediaType(v) 604 | if err != nil { 605 | continue 606 | } 607 | 608 | switch strings.ToLower(mt) { 609 | case "application/json": 610 | payload, err := c.tryReadInputPayloadAsJSON() 611 | if err != nil { 612 | continue 613 | } 614 | return payload, nil 615 | default: 616 | continue 617 | } 618 | } 619 | 620 | return InputPayloadWithColumns{}, ErrUnsupportedMediaType 621 | } 622 | 623 | func (c *queryCompiler) tryReadInputPayloadAsJSON() (InputPayloadWithColumns, error) { 624 | rv := InputPayloadWithColumns{ 625 | Columns: map[string]struct{}{}, 626 | } 627 | 628 | body, err := c.readyRequestBody() 629 | if err != nil { 630 | return rv, err 631 | } 632 | 633 | // TODO: we need a Peek method from json.Decoder 634 | enc := json.NewDecoder(bytes.NewBuffer(body)) 635 | tok, err := enc.Token() 636 | if err != nil { 637 | return rv, err 638 | } 639 | switch tok { 640 | case json.Delim('['): 641 | // a json array 642 | var ps []map[string]interface{} 643 | if err := json.Unmarshal(body, &ps); err != nil { 644 | return rv, err 645 | } 646 | rv.Payload = append(rv.Payload, ps...) 647 | default: 648 | // try as single object 649 | var p map[string]interface{} 650 | if err := json.Unmarshal(body, &p); err != nil { 651 | return rv, err 652 | } 653 | rv.Payload = append(rv.Payload, p) 654 | } 655 | 656 | for _, p := range rv.Payload { 657 | for k := range p { 658 | rv.Columns[k] = struct{}{} 659 | } 660 | } 661 | 662 | return rv, nil 663 | } 664 | 665 | func (c *queryCompiler) readyRequestBody() ([]byte, error) { 666 | source := c.req.Body 667 | defer source.Close() 668 | b, err := io.ReadAll(source) 669 | if err != nil { 670 | return nil, fmt.Errorf("read request body: %w", err) 671 | } 672 | c.req.Body = io.NopCloser(bytes.NewBuffer(b)) 673 | 674 | return b, nil 675 | } 676 | 677 | type CompiledQueryParameter struct { 678 | Expr string 679 | Values []interface{} 680 | } 681 | 682 | func negateCompiledQueryParameters( 683 | qps []CompiledQueryParameter, 684 | err error, 685 | ) ([]CompiledQueryParameter, error) { 686 | if err != nil { 687 | return qps, err 688 | } 689 | 690 | if len(qps) < 1 { 691 | return qps, nil 692 | } 693 | 694 | negatedResult := CompiledQueryParameter{} 695 | 696 | var subExprs []string 697 | for _, p := range qps { 698 | subExprs = append(subExprs, p.Expr) 699 | negatedResult.Values = append(negatedResult.Values, p.Values...) 
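// Illustrative walk-through (example values only): a filter such as
// not.and(age.ge.11,age.le.17) first compiles into a single AND-joined
// parameter, roughly "(age >= ? and age <= ?)" with the values "11" and "17"
// bound as strings; the loop above gathers those sub-expressions and values,
// and the Sprintf just below wraps the joined expression as
// "(not ((age >= ? and age <= ?)))" while keeping the value order intact.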
700 | } 701 | negatedResult.Expr = fmt.Sprintf( 702 | "(not (%s))", 703 | strings.Join(subExprs, " and "), 704 | ) 705 | 706 | return []CompiledQueryParameter{negatedResult}, nil 707 | } 708 | 709 | func joinCompiledQueryParameters( 710 | operator string, 711 | ) func([]CompiledQueryParameter, error) ([]CompiledQueryParameter, error) { 712 | return func( 713 | qps []CompiledQueryParameter, 714 | err error, 715 | ) ([]CompiledQueryParameter, error) { 716 | if err != nil { 717 | return qps, err 718 | } 719 | 720 | if len(qps) < 1 { 721 | return qps, nil 722 | } 723 | 724 | rv := CompiledQueryParameter{} 725 | 726 | var subExprs []string 727 | for _, p := range qps { 728 | subExprs = append(subExprs, p.Expr) 729 | rv.Values = append(rv.Values, p.Values...) 730 | } 731 | rv.Expr = fmt.Sprintf( 732 | "(%s)", 733 | strings.Join(subExprs, fmt.Sprintf(" %s ", operator)), 734 | ) 735 | 736 | return []CompiledQueryParameter{rv}, nil 737 | } 738 | } 739 | 740 | var ( 741 | andCompiledQueryParameters = joinCompiledQueryParameters("and") 742 | orCompiledQueryParameters = joinCompiledQueryParameters("or") 743 | ) 744 | 745 | // (age.eq.14,not.and(age.gte.11,age.lte.17)) => [age.eq.14, not.and(age.gte.11,age.lte.17)] 746 | func tokenizeSubQueries(s string) []string { 747 | for strings.HasPrefix(s, "(") && strings.HasSuffix(s, ")") { 748 | s = s[1 : len(s)-1] 749 | } 750 | 751 | var rv []string 752 | 753 | var current string 754 | subQueryLevel := 0 755 | for _, c := range s { 756 | if c == ',' && subQueryLevel == 0 { 757 | rv = append(rv, current) 758 | current = "" 759 | continue 760 | } 761 | if c == '(' { 762 | subQueryLevel += 1 763 | } 764 | if c == ')' { 765 | subQueryLevel -= 1 766 | } 767 | current = current + string(c) 768 | } 769 | if current != "" { 770 | rv = append(rv, current) 771 | } 772 | 773 | return rv 774 | } 775 | 776 | // parseQueryClauses parses user input queries to query parameters. 777 | // FIXME: this is a very naive and slow (O(n^2)) parser. We should employ a proper lexer & parser. 778 | func parseQueryClauses(s string) ([]CompiledQueryParameter, error) { 779 | const ( 780 | logicalOperatorNotPrefix = logicalOperatorNot + "." 781 | logicalOperatorAndPrefix = logicalOperatorAnd + "(" 782 | logicalOperatorOrPrefix = logicalOperatorOr + "(" 783 | ) 784 | 785 | switch { 786 | case s == "": 787 | return nil, nil 788 | case strings.HasPrefix(s, "("): 789 | if !strings.HasSuffix(s, ")") { 790 | return nil, ErrBadRequest.WithHint(fmt.Sprintf("incomplete sub query: %q", s)) 791 | } 792 | subQueries := tokenizeSubQueries(s) 793 | if len(subQueries) < 1 { 794 | return nil, ErrBadRequest.WithHint(fmt.Sprintf("invalid sub query: %q", s)) 795 | } 796 | 797 | var rv []CompiledQueryParameter 798 | for _, subQuery := range subQueries { 799 | q, err := parseQueryClauses(subQuery) 800 | if err != nil { 801 | return nil, err 802 | } 803 | rv = append(rv, q...) 
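// Illustrative example (the input shown is hypothetical): a grouped filter
// like (age.eq.14,not.and(age.ge.11,age.le.17)) is split by tokenizeSubQueries
// into ["age.eq.14", "not.and(age.ge.11,age.le.17)"]; each token is parsed
// recursively here and its parameters are appended side by side, typically
// leaving it to the surrounding and(...) / or(...) (or not.) handling to join
// them with the appropriate logical operator.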
804 | }
805 |
806 | return rv, nil
807 | case strings.HasPrefix(s, logicalOperatorNotPrefix):
808 | return negateCompiledQueryParameters(parseQueryClauses(s[len(logicalOperatorNotPrefix):]))
809 | case strings.HasPrefix(s, logicalOperatorAndPrefix):
810 | return andCompiledQueryParameters(parseQueryClauses(s[len(logicalOperatorAnd):]))
811 | case strings.HasPrefix(s, logicalOperatorOrPrefix):
812 | return orCompiledQueryParameters(parseQueryClauses(s[len(logicalOperatorOr):]))
813 | default:
814 | // column.operator.value | column.not.operator.value
815 | ps := strings.SplitN(s, ".", 3)
816 | if len(ps) != 3 {
817 | return nil, ErrBadRequest.WithHint(fmt.Sprintf("invalid query clause: %q", s))
818 | }
819 | column, op, value := ps[0], ps[1], ps[2]
820 | negate := false
821 | if op == logicalOperatorNot {
822 | negate = true
823 | ps := strings.SplitN(value, ".", 2)
824 | if len(ps) != 2 {
825 | return nil, ErrBadRequest.WithHint(fmt.Sprintf("invalid query clause: %q", s))
826 | }
827 | op, value = ps[0], ps[1]
828 | }
829 |
830 | opProcess, exists := queryOpereators[op]
831 | if !exists {
832 | return nil, ErrUnsupportedOperator(s)
833 | }
834 |
835 | rv, err := opProcess(column, op, value)
836 | if err != nil {
837 | return nil, err
838 | }
839 | if negate {
840 | rv, _ = negateCompiledQueryParameters(rv, nil)
841 | }
842 | return rv, nil
843 | }
844 | }
845 |
846 | type queryOpereatorUserInputParseFunc func(column string, userInput string, value string) ([]CompiledQueryParameter, error)
847 |
848 | func mapUserInputAsUnaryQuery(op string) queryOpereatorUserInputParseFunc {
849 | return func(column string, userInput string, value string) ([]CompiledQueryParameter, error) {
850 | rv := []CompiledQueryParameter{
851 | {
852 | Expr: fmt.Sprintf("%s %s ?", column, op),
853 | Values: []interface{}{value},
854 | },
855 | }
856 |
857 | return rv, nil
858 | }
859 | }
860 |
861 | func mapAsInQuery(column string, userInput string, value string) ([]CompiledQueryParameter, error) {
862 | value = strings.TrimPrefix(value, "(")
863 | value = strings.TrimSuffix(value, ")")
864 | value = fmt.Sprintf("[%s]", value)
865 | var ps []interface{}
866 | // FIXME: parsing user input as JSON is not 100% safe
867 | if err := json.Unmarshal([]byte(value), &ps); err != nil {
868 | return nil, err
869 | }
870 | if len(ps) == 0 { return nil, ErrBadRequest.WithHint("in list must not be empty") } // guard: an empty "in.()" would make the strings.Repeat below panic
871 | rv := []CompiledQueryParameter{
872 | {
873 | Expr: fmt.Sprintf("%s IN (%s)", column, strings.Repeat("?,", len(ps)-1)+"?"),
874 | Values: ps,
875 | },
876 | }
877 |
878 | return rv, nil
879 | }
880 |
881 | func mapAsIsQuery(column string, userInput string, value string) ([]CompiledQueryParameter, error) {
882 | rv := CompiledQueryParameter{
883 | Expr: fmt.Sprintf("%s IS ?", column),
884 | Values: []interface{}{},
885 | }
886 |
887 | switch strings.ToLower(value) {
888 | case "null":
889 | rv.Values = append(rv.Values, nil)
890 | case "false":
891 | rv.Values = append(rv.Values, false)
892 | case "true":
893 | rv.Values = append(rv.Values, true)
894 | default:
895 | return nil, ErrUnsupportedOperator(fmt.Sprintf("%s.%s", userInput, value))
896 | }
897 |
898 | return []CompiledQueryParameter{rv}, nil
899 | }
900 |
901 | // ref: https://postgrest.org/en/stable/api.html#operators
902 | var queryOpereators = map[string]queryOpereatorUserInputParseFunc{
903 | "eq": mapUserInputAsUnaryQuery("="),
904 | "gt": mapUserInputAsUnaryQuery(">"), "ge": mapUserInputAsUnaryQuery(">="),
905 | "lt": mapUserInputAsUnaryQuery("<"), "le": mapUserInputAsUnaryQuery("<="),
906 | "neq": mapUserInputAsUnaryQuery("!="),
907
| "like": mapUserInputAsUnaryQuery("LIKE"), "ilike": mapUserInputAsUnaryQuery("ILIKE"), 908 | "in": mapAsInQuery, 909 | "is": mapAsIsQuery, 910 | // fts / plfts / phfts / wfts are unsupported 911 | // cs / cd / ov are unsupported 912 | // sl / sr / nxr / nxl / adj are unsupported 913 | } 914 | 915 | type InputPayloadWithColumns struct { 916 | Columns map[string]struct{} 917 | Payload []map[string]interface{} 918 | } 919 | 920 | func (p InputPayloadWithColumns) GetSortedColumns() []string { 921 | columns := make([]string, 0, len(p.Columns)) 922 | for column := range p.Columns { 923 | columns = append(columns, column) 924 | } 925 | sort.Strings(columns) 926 | return columns 927 | } 928 | 929 | func (p InputPayloadWithColumns) GetValues(columns []string) [][]interface{} { 930 | var rv [][]interface{} 931 | for _, p := range p.Payload { 932 | var row []interface{} 933 | for _, column := range columns { 934 | v, exists := p[column] 935 | if exists { 936 | row = append(row, v) 937 | } else { 938 | row = append(row, nil) 939 | } 940 | } 941 | rv = append(rv, row) 942 | } 943 | 944 | return rv 945 | } 946 | 947 | // CountMethod specifies the count method for the request. 948 | type CountMethod string 949 | 950 | const ( 951 | countNone CountMethod = "" // fallback 952 | countExact CountMethod = "exact" 953 | // TODO: support planned / estimated count 954 | ) 955 | 956 | // Valid checks if the count method is valid. 957 | func (c CountMethod) Valid() bool { 958 | switch c { 959 | case countNone, countExact: 960 | return true 961 | default: 962 | return false 963 | } 964 | } 965 | 966 | // ResolutionMethod specifies the conflict resolution for the request. 967 | type ResolutionMethod string 968 | 969 | const ( 970 | resolutionNone = "" // fallback 971 | resolutionMergeDuplicates = "merge-duplicates" 972 | resolutionIgnoreDuplicates = "ignore-duplicates" 973 | ) 974 | 975 | // Valid checks if the resolution method is valid. 
976 | func (r ResolutionMethod) Valid() bool { 977 | switch r { 978 | case resolutionNone, resolutionIgnoreDuplicates, resolutionMergeDuplicates: 979 | return true 980 | default: 981 | return false 982 | } 983 | } 984 | 985 | type Preference struct { 986 | Resolution ResolutionMethod 987 | Count CountMethod 988 | // TODO: retrun 989 | } 990 | 991 | func ParsePreferenceFromRequest(req *http.Request) (Preference, error) { 992 | var rv Preference 993 | 994 | v := req.Header.Get(headerNamePrefer) 995 | if v == "" { 996 | return rv, nil 997 | } 998 | 999 | for _, p := range strings.Split(v, ",") { 1000 | p = strings.TrimSpace(p) 1001 | if p == "" { 1002 | continue 1003 | } 1004 | // a=b => a,b 1005 | ps := strings.SplitN(p, "=", 2) 1006 | if len(ps) < 2 { 1007 | continue 1008 | } 1009 | 1010 | switch strings.ToLower(ps[0]) { 1011 | case "count": 1012 | countMethod := CountMethod(strings.ToLower(ps[1])) 1013 | if countMethod.Valid() { 1014 | rv.Count = countMethod 1015 | } else { 1016 | return rv, ErrBadRequest.WithHint(fmt.Sprintf("unsupported count preference: %s", ps[1])) 1017 | } 1018 | case "resolution": 1019 | resolution := ResolutionMethod(strings.ToLower(ps[1])) 1020 | if resolution.Valid() { 1021 | rv.Resolution = resolution 1022 | } else { 1023 | return rv, ErrBadRequest.WithHint(fmt.Sprintf("unsupported resolution preference: %s", ps[1])) 1024 | } 1025 | } 1026 | } 1027 | 1028 | return rv, nil 1029 | } 1030 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "net/http" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | "time" 13 | 14 | "github.com/go-chi/chi/v5" 15 | "github.com/go-chi/chi/v5/middleware" 16 | "github.com/go-chi/cors" 17 | "github.com/go-logr/logr" 18 | "github.com/jmoiron/sqlx" 19 | "github.com/spf13/cobra" 20 | "github.com/spf13/pflag" 21 | ) 22 | 23 | const ( 24 | routeVarTableOrView = "tableOrView" 25 | ) 26 | 27 | type ServerOptions struct { 28 | Logger logr.Logger 29 | Addr string 30 | AuthOptions ServerAuthOptions 31 | SecurityOptions ServerSecurityOptions 32 | Queryer sqlx.QueryerContext 33 | Execer sqlx.ExecerContext 34 | } 35 | 36 | func (opts *ServerOptions) bindCLIFlags(fs *pflag.FlagSet) { 37 | fs.StringVar(&opts.Addr, "http-addr", ":8080", "server listen address") 38 | 39 | opts.AuthOptions.bindCLIFlags(fs) 40 | opts.SecurityOptions.bindCLIFlags(fs) 41 | } 42 | 43 | func (opts *ServerOptions) defaults() error { 44 | if err := opts.AuthOptions.defaults(); err != nil { 45 | return err 46 | } 47 | if err := opts.SecurityOptions.defaults(); err != nil { 48 | return err 49 | } 50 | 51 | if opts.Logger.GetSink() == nil { 52 | opts.Logger = logr.Discard() 53 | } 54 | 55 | if opts.Addr == "" { 56 | opts.Addr = ":8080" 57 | } 58 | 59 | if opts.Queryer == nil { 60 | return fmt.Errorf(".Queryer is required") 61 | } 62 | 63 | if opts.Execer == nil { 64 | return fmt.Errorf(".Execer is required") 65 | } 66 | 67 | return nil 68 | } 69 | 70 | type dbServer struct { 71 | logger logr.Logger 72 | server *http.Server 73 | queryer sqlx.QueryerContext 74 | execer sqlx.ExecerContext 75 | } 76 | 77 | func NewServer(opts *ServerOptions) (*dbServer, error) { 78 | if err := opts.defaults(); err != nil { 79 | return nil, err 80 | } 81 | 82 | rv := &dbServer{ 83 | logger: opts.Logger.WithName("db-server"), 84 | server: &http.Server{ 85 | Addr: opts.Addr, 86 | // TODO: make 
it configurable 87 | ReadHeaderTimeout: 5 * time.Second, 88 | }, 89 | queryer: opts.Queryer, 90 | execer: opts.Execer, 91 | } 92 | 93 | serverMux := chi.NewRouter() 94 | 95 | // TODO: allow specifying cors config from cli / table 96 | serverMux.Use( 97 | middleware.RequestID, 98 | middleware.RealIP, 99 | serverLogger(rv.logger), 100 | cors.AllowAll().Handler, 101 | ) 102 | 103 | { 104 | serverMux. 105 | With( 106 | opts.AuthOptions.createAuthMiddleware(func(w http.ResponseWriter, err error) { 107 | metricsAuthFailedRequestsTotal.Inc() 108 | rv.responseError(w, err) 109 | }), 110 | opts.SecurityOptions.createTableOrViewAccessCheckMiddleware(func(w http.ResponseWriter, err error) { 111 | metricsAccessCheckFailedRequestsTotal.Inc() 112 | rv.responseError(w, err) 113 | }), 114 | ). 115 | Group(func(r chi.Router) { 116 | routePattern := fmt.Sprintf("/{%s:[^/]+}", routeVarTableOrView) 117 | r.With(recordRequestMetrics("queryTableOrView")).Get(routePattern, rv.handleQueryTableOrView) 118 | r.With(recordRequestMetrics("insertTable")).Post(routePattern, rv.handleInsertTable) 119 | r.With(recordRequestMetrics("updateTable")).Patch(routePattern, rv.handleUpdateTable) 120 | r.With(recordRequestMetrics("updateSingleEntity")).Put(routePattern, rv.handleUpdateSingleEntity) 121 | r.With(recordRequestMetrics("deleteTable")).Delete(routePattern, rv.handleDeleteTable) 122 | }) 123 | } 124 | 125 | rv.server.Handler = serverMux 126 | 127 | return rv, nil 128 | } 129 | 130 | func (server *dbServer) Start(done <-chan struct{}) { 131 | go server.server.ListenAndServe() 132 | 133 | server.logger.Info("server started", "addr", server.server.Addr) 134 | <-done 135 | 136 | server.logger.Info("shutting down server") 137 | shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) 138 | defer cancel() 139 | server.server.Shutdown(shutdownCtx) 140 | } 141 | 142 | func (server *dbServer) responseHeader(w http.ResponseWriter, statusCode int) { 143 | w.Header().Set("Server", ServerVersion) 144 | w.WriteHeader(statusCode) 145 | } 146 | 147 | func (server *dbServer) responseError(w http.ResponseWriter, err error) { 148 | var serverError *ServerError 149 | switch { 150 | case errors.As(err, &serverError): 151 | server.responseData(w, serverError, serverError.StatusCode) 152 | default: 153 | resp := &ServerError{Message: err.Error()} 154 | server.responseData(w, resp, http.StatusInternalServerError) 155 | } 156 | } 157 | 158 | func (server *dbServer) responseData(w http.ResponseWriter, data interface{}, statusCode int) { 159 | server.responseHeader(w, statusCode) 160 | 161 | enc := json.NewEncoder(w) 162 | if encodeErr := enc.Encode(data); encodeErr != nil { 163 | server.logger.Error(encodeErr, "failed to write response") 164 | w.WriteHeader(http.StatusInternalServerError) 165 | return 166 | } 167 | } 168 | 169 | func (server *dbServer) responseEmptyBody(w http.ResponseWriter, statusCode int) { 170 | server.responseHeader(w, statusCode) 171 | } 172 | 173 | func (server *dbServer) handleQueryTableOrView( 174 | w http.ResponseWriter, 175 | req *http.Request, 176 | ) { 177 | target := chi.URLParam(req, routeVarTableOrView) 178 | 179 | logger := server.logger.WithValues("target", target, "route", "handleQueryTableOrView") 180 | 181 | qc := NewQueryCompilerFromRequest(req) 182 | selectStmt, err := qc.CompileAsSelect(target) 183 | if err != nil { 184 | logger.Error(err, "parse select query") 185 | server.responseError(w, err) 186 | return 187 | } 188 | logger.V(8).Info(selectStmt.Query) 189 | 190 | rows, err := 
server.queryer.QueryxContext(req.Context(), selectStmt.Query, selectStmt.Values...) 191 | if err != nil { 192 | logger.Error(err, "query values") 193 | server.responseError(w, err) 194 | return 195 | } 196 | defer rows.Close() 197 | 198 | // make sure return list instead of null for empty list 199 | // FIXME: reflect column type and scan typed value instead of using `interface{}` 200 | rv := make([]map[string]interface{}, 0) 201 | rows.ColumnTypes() 202 | for rows.Next() { 203 | p := make(map[string]interface{}) 204 | if err := rows.MapScan(p); err != nil { 205 | server.responseError(w, err) 206 | return 207 | } 208 | rv = append(rv, p) 209 | } 210 | 211 | responseStatusCode := http.StatusOK 212 | 213 | w.Header().Set("Content-Type", "application/json") // TODO: horner request config 214 | 215 | preference, err := ParsePreferenceFromRequest(req) 216 | if err != nil { 217 | logger.Error(err, "parse preference") 218 | server.responseError(w, err) 219 | return 220 | } 221 | var countTotal string 222 | switch preference.Count { 223 | case countNone: 224 | countTotal = "*" 225 | case countExact: 226 | responseStatusCode = http.StatusPartialContent 227 | 228 | countStmt, err := qc.CompileAsExactCount(target) 229 | if err != nil { 230 | logger.Error(err, "parse count query") 231 | server.responseError(w, err) 232 | return 233 | } 234 | logger.V(8).Info(countStmt.Query) 235 | 236 | var count int64 237 | if err := server.queryer.QueryRowxContext( 238 | req.Context(), 239 | countStmt.Query, countStmt.Values..., 240 | ).Scan(&count); err != nil { 241 | logger.Error(err, "count values") 242 | server.responseError(w, err) 243 | return 244 | } 245 | countTotal = fmt.Sprint(count) 246 | } 247 | 248 | if v := qc.CompileContentRangeHeader(countTotal); v != "" { 249 | w.Header().Set("Range-Unit", "items") 250 | w.Header().Set("Content-Range", v) 251 | } 252 | 253 | server.responseData(w, rv, responseStatusCode) 254 | } 255 | 256 | func (server *dbServer) handleInsertTable( 257 | w http.ResponseWriter, 258 | req *http.Request, 259 | ) { 260 | target := chi.URLParam(req, routeVarTableOrView) 261 | 262 | logger := server.logger.WithValues("target", target, "route", "handleInsertTable") 263 | 264 | qc := NewQueryCompilerFromRequest(req) 265 | insertStmt, err := qc.CompileAsInsert(target) 266 | if err != nil { 267 | logger.Error(err, "parse insert query") 268 | server.responseError(w, err) 269 | return 270 | } 271 | logger.V(8).Info(insertStmt.Query) 272 | 273 | _, err = server.execer.ExecContext(req.Context(), insertStmt.Query, insertStmt.Values...) 274 | if err != nil { 275 | server.responseError(w, err) 276 | return 277 | } 278 | 279 | // TODO: implement support for retrieving object by inserted id 280 | server.responseEmptyBody(w, http.StatusCreated) 281 | } 282 | 283 | func (server *dbServer) handleUpdateTable( 284 | w http.ResponseWriter, 285 | req *http.Request, 286 | ) { 287 | target := chi.URLParam(req, routeVarTableOrView) 288 | 289 | logger := server.logger.WithValues("target", target, "route", "handleUpdateTable") 290 | 291 | qc := NewQueryCompilerFromRequest(req) 292 | updateStmt, err := qc.CompileAsUpdate(target) 293 | if err != nil { 294 | logger.Error(err, "parse update query") 295 | server.responseError(w, err) 296 | return 297 | } 298 | logger.V(8).Info(updateStmt.Query) 299 | 300 | _, err = server.execer.ExecContext(req.Context(), updateStmt.Query, updateStmt.Values...) 
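// Illustrative request flow (table and column names are hypothetical):
// a PATCH /books?id=eq.1 request with body {"price": 42} is compiled by
// CompileAsUpdate into an UPDATE statement whose SET columns come from the
// JSON payload and whose WHERE clause comes from the query-string filter
// (roughly UPDATE books SET price = ? WHERE (id = ?)); the statement is
// executed above and, on success, the handler replies 202 Accepted with an
// empty body just below.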
301 | if err != nil { 302 | server.responseError(w, err) 303 | return 304 | } 305 | 306 | server.responseEmptyBody(w, http.StatusAccepted) 307 | } 308 | 309 | func (server *dbServer) handleUpdateSingleEntity( 310 | w http.ResponseWriter, 311 | req *http.Request, 312 | ) { 313 | target := chi.URLParam(req, routeVarTableOrView) 314 | 315 | logger := server.logger.WithValues("target", target, "route", "handleUpdateSingleEntity") 316 | 317 | qc := NewQueryCompilerFromRequest(req) 318 | updateStmt, err := qc.CompileAsUpdateSingleEntry(target) 319 | if err != nil { 320 | logger.Error(err, "parse update single entry query") 321 | server.responseError(w, err) 322 | return 323 | } 324 | logger.V(8).Info(updateStmt.Query) 325 | 326 | _, err = server.execer.ExecContext(req.Context(), updateStmt.Query, updateStmt.Values...) 327 | if err != nil { 328 | server.responseError(w, err) 329 | return 330 | } 331 | } 332 | 333 | func (server *dbServer) handleDeleteTable( 334 | w http.ResponseWriter, 335 | req *http.Request, 336 | ) { 337 | target := chi.URLParam(req, routeVarTableOrView) 338 | 339 | logger := server.logger.WithValues("target", target, "route", "handleDeleteTable") 340 | 341 | qc := NewQueryCompilerFromRequest(req) 342 | updateStmt, err := qc.CompileAsDelete(target) 343 | if err != nil { 344 | logger.Error(err, "parse delete query") 345 | server.responseError(w, err) 346 | return 347 | } 348 | logger.V(8).Info(updateStmt.Query) 349 | 350 | _, err = server.execer.ExecContext(req.Context(), updateStmt.Query, updateStmt.Values...) 351 | if err != nil { 352 | server.responseError(w, err) 353 | return 354 | } 355 | 356 | server.responseEmptyBody(w, http.StatusAccepted) 357 | } 358 | 359 | func createServeCmd() *cobra.Command { 360 | serverOpts := new(ServerOptions) 361 | metricsServerOpts := new(MetricsServerOptions) 362 | pprofServerOpts := new(PprofServerOptions) 363 | 364 | cmd := &cobra.Command{ 365 | Use: "serve", 366 | Short: "Start database server", 367 | SilenceUsage: true, 368 | SilenceErrors: true, 369 | RunE: func(cmd *cobra.Command, args []string) error { 370 | logger, err := createLogger(cmd) 371 | if err != nil { 372 | setupLogger.Error(err, "failed to create logger") 373 | return err 374 | } 375 | 376 | db, err := openDB(cmd) 377 | if err != nil { 378 | setupLogger.Error(err, "failed to open db") 379 | return err 380 | } 381 | defer db.Close() 382 | 383 | serverOpts.Logger = logger 384 | serverOpts.Queryer = db 385 | serverOpts.Execer = db 386 | 387 | server, err := NewServer(serverOpts) 388 | if err != nil { 389 | setupLogger.Error(err, "failed to create server") 390 | return err 391 | } 392 | 393 | metricsServerOpts.Logger = logger 394 | metricsServerOpts.Queryer = db 395 | metricsServer, err := NewMetricsServer(*metricsServerOpts) 396 | if err != nil { 397 | setupLogger.Error(err, "failed to create metrics server") 398 | return err 399 | } 400 | 401 | pprofServerOpts.Logger = logger 402 | pprofServer, err := NewPprofServer(*pprofServerOpts) 403 | if err != nil { 404 | setupLogger.Error(err, "failed to create pprof server") 405 | return err 406 | } 407 | 408 | ctx, cancel := context.WithCancel(context.Background()) 409 | defer cancel() 410 | 411 | sigs := make(chan os.Signal, 1) 412 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 413 | 414 | done := ctx.Done() 415 | 416 | go metricsServer.Start(done) 417 | go pprofServer.Start(done) 418 | go server.Start(done) 419 | <-sigs 420 | 421 | return nil 422 | }, 423 | } 424 | 425 | serverOpts.bindCLIFlags(cmd.Flags()) 426 | 
metricsServerOpts.bindCLIFlags(cmd.Flags())
427 | pprofServerOpts.bindCLIFlags(cmd.Flags())
428 | bindDBDSNFlag(cmd.Flags())
429 |
430 | return cmd
431 | }
432 |
--------------------------------------------------------------------------------
/server_auth.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "strings"
7 |
8 | "github.com/golang-jwt/jwt"
9 | "github.com/spf13/pflag"
10 | )
11 |
12 | const (
13 | headerNameAuthorizer = "Authorization"
14 | headerPrefixBearer = "Bearer"
15 | )
16 |
17 | type ServerAuthOptions struct {
18 | RSAPublicKeyFilePath string
19 | TokenFilePath string
20 |
21 | // for unit test
22 | disableAuth bool
23 | }
24 |
25 | func (opts *ServerAuthOptions) bindCLIFlags(fs *pflag.FlagSet) {
26 | fs.StringVar(&opts.RSAPublicKeyFilePath, "auth-rsa-public-key", "", "path to the RSA public key file")
27 | fs.StringVar(&opts.TokenFilePath, "auth-token-file", "", "path to the token file")
28 | }
29 |
30 | func (opts *ServerAuthOptions) defaults() error {
31 | if opts.disableAuth {
32 | return nil
33 | }
34 |
35 | if opts.RSAPublicKeyFilePath == "" && opts.TokenFilePath == "" {
36 | return fmt.Errorf("specify at least one of --auth-rsa-public-key or --auth-token-file")
37 | }
38 |
39 | if opts.RSAPublicKeyFilePath != "" && opts.TokenFilePath != "" {
40 | return fmt.Errorf("cannot specify both --auth-rsa-public-key and --auth-token-file at the same time")
41 | }
42 |
43 | return nil
44 | }
45 |
46 | func (opts *ServerAuthOptions) createAuthMiddleware(
47 | responseErr func(w http.ResponseWriter, err error),
48 | ) func(http.Handler) http.Handler {
49 | if opts.disableAuth {
50 | return func(next http.Handler) http.Handler {
51 | return next
52 | }
53 | }
54 |
55 | jwtParser := &jwt.Parser{
56 | ValidMethods: []string{},
57 | SkipClaimsValidation: false,
58 | }
59 |
60 | jwtKeyFunc := jwt.Keyfunc(func(t *jwt.Token) (interface{}, error) {
61 | return nil, fmt.Errorf("invalid token")
62 | })
63 |
64 | // NOTE: the key material is re-read from disk so that rotated public keys / tokens are picked up without a restart
65 |
66 | switch {
67 | case opts.RSAPublicKeyFilePath != "":
68 | keyReader := readFileWithStatCache(opts.RSAPublicKeyFilePath)
69 |
70 | jwtParser.ValidMethods = append(
71 | jwtParser.ValidMethods,
72 | jwt.SigningMethodRS256.Name,
73 | jwt.SigningMethodRS384.Name,
74 | jwt.SigningMethodRS512.Name,
75 | )
76 | jwtKeyFunc = func(t *jwt.Token) (interface{}, error) {
77 | b, err := keyReader()
78 | if err != nil {
79 | return nil, err
80 | }
81 |
82 | v, err := jwt.ParseRSAPublicKeyFromPEM(b)
83 | if err != nil {
84 | return nil, err
85 | }
86 | return v, nil
87 | }
88 | case opts.TokenFilePath != "":
89 | tokenReader := readFileWithStatCache(opts.TokenFilePath)
90 |
91 | jwtParser.ValidMethods = append(
92 | jwtParser.ValidMethods,
93 | jwt.SigningMethodHS256.Name,
94 | jwt.SigningMethodHS384.Name,
95 | jwt.SigningMethodHS512.Name,
96 | )
97 | jwtKeyFunc = func(t *jwt.Token) (interface{}, error) {
98 | b, err := tokenReader()
99 | if err != nil {
100 | return nil, err
101 | }
102 |
103 | return b, nil
104 | }
105 | }
106 |
107 | return func(next http.Handler) http.Handler {
108 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
109 | v := r.Header.Get(headerNameAuthorizer)
110 |
111 | if v == "" {
112 | responseErr(w, ErrUnauthorized.WithHint("missing auth header"))
113 | return
114 | }
115 |
116 | ps := strings.SplitN(v, " ", 2)
117 | if len(ps) != 2 {
118 | responseErr(w,
ErrUnauthorized.WithHint("invalid auth header")) 119 | return 120 | } 121 | 122 | if !strings.EqualFold(ps[0], headerPrefixBearer) { 123 | responseErr(w, ErrUnauthorized.WithHint("invalid auth header")) 124 | return 125 | } 126 | 127 | // TODO: add rbac support 128 | _, err := jwtParser.Parse(ps[1], jwtKeyFunc) 129 | if err != nil { 130 | responseErr(w, ErrUnauthorized.WithHint(err.Error())) 131 | return 132 | } 133 | 134 | next.ServeHTTP(w, r) 135 | }) 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /server_errors.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | type ServerError struct { 9 | Message string `json:"message"` 10 | Code string `json:"code,omitempty"` 11 | Hint string `json:"hint,omitempty"` 12 | StatusCode int `json:"-"` 13 | } 14 | 15 | func (se *ServerError) Error() string { 16 | if se.Hint != "" { 17 | return fmt.Sprintf("%s - %s", se.Message, se.Hint) 18 | } 19 | return se.Message 20 | } 21 | 22 | func (se *ServerError) WithHint(hint string) *ServerError { 23 | rv := new(ServerError) 24 | *rv = *se 25 | rv.Hint = hint 26 | return rv 27 | } 28 | 29 | var ( 30 | ErrUnsupportedMediaType = &ServerError{ 31 | Message: "Unsupported Media Type", 32 | StatusCode: http.StatusUnsupportedMediaType, 33 | } 34 | 35 | ErrBadRequest = &ServerError{ 36 | Message: "Bad Request", 37 | StatusCode: http.StatusBadRequest, 38 | } 39 | 40 | ErrUnauthorized = &ServerError{ 41 | Message: "Unauthorized", 42 | StatusCode: http.StatusUnauthorized, 43 | } 44 | 45 | ErrAccessRestricted = &ServerError{ 46 | Message: "Access Restricted", 47 | StatusCode: http.StatusForbidden, 48 | } 49 | ) 50 | 51 | func ErrUnsupportedOperator(op string) *ServerError { 52 | return &ServerError{ 53 | Message: "Unsupported Operator", 54 | Hint: fmt.Sprintf("operator %q is unsupported", op), 55 | StatusCode: http.StatusBadRequest, 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /server_security.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/go-chi/chi/v5" 7 | "github.com/spf13/pflag" 8 | ) 9 | 10 | // TODO: generally speaking, we need a fine-grained RBAC system. 11 | 12 | type ServerSecurityOptions struct { 13 | // EnabledTableOrViews list of table or view names that are accessible (read & write). 
14 | EnabledTableOrViews []string 15 | } 16 | 17 | func (opts *ServerSecurityOptions) bindCLIFlags(fs *pflag.FlagSet) { 18 | fs.StringSliceVar( 19 | &opts.EnabledTableOrViews, 20 | "security-allow-table", 21 | []string{}, 22 | "list of table or view names that are accessible (read & write)", 23 | ) 24 | } 25 | 26 | func (opts *ServerSecurityOptions) defaults() error { 27 | return nil 28 | } 29 | 30 | func (opts *ServerSecurityOptions) createTableOrViewAccessCheckMiddleware( 31 | responseErr func(w http.ResponseWriter, err error), 32 | ) func(http.Handler) http.Handler { 33 | accessibleTableOrViews := make(map[string]struct{}) 34 | for _, t := range opts.EnabledTableOrViews { 35 | accessibleTableOrViews[t] = struct{}{} 36 | } 37 | 38 | return func(next http.Handler) http.Handler { 39 | return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 40 | target := chi.URLParam(req, routeVarTableOrView) 41 | 42 | if _, ok := accessibleTableOrViews[target]; !ok { 43 | responseErr(w, ErrAccessRestricted) 44 | return 45 | } 46 | 47 | next.ServeHTTP(w, req) 48 | }) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /server_utils.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/go-chi/chi/v5/middleware" 8 | "github.com/go-logr/logr" 9 | ) 10 | 11 | type httpLogger struct { 12 | logr.Logger 13 | } 14 | 15 | func (l httpLogger) Print(v ...interface{}) { 16 | l.Info(fmt.Sprint(v...)) 17 | } 18 | 19 | func serverLogger(logr logr.Logger) func(http.Handler) http.Handler { 20 | formatter := &middleware.DefaultLogFormatter{ 21 | Logger: httpLogger{logr}, 22 | } 23 | return middleware.RequestLogger(formatter) 24 | } 25 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "runtime/debug" 6 | ) 7 | 8 | // ServerVersion defines the server application version. 9 | // Use -ldflags "-X main.ServerVersion=1.0.0" to override the version. 10 | var ServerVersion string 11 | 12 | func loadServerVersionFromBuildInfo() string { 13 | info, ok := debug.ReadBuildInfo() 14 | if !ok { 15 | return "" 16 | } 17 | 18 | var ( 19 | commit string = "unknown" 20 | dirty bool 21 | ) 22 | for _, s := range info.Settings { 23 | switch { 24 | case s.Key == "vcs.revision": 25 | commit = s.Value 26 | if len(s.Value) > 10 { 27 | commit = commit[:10] 28 | } 29 | case s.Key == "vcs.modified": 30 | dirty = s.Value == "true" 31 | } 32 | } 33 | if dirty { 34 | commit += "-dirty" 35 | } 36 | 37 | s := fmt.Sprintf("sqlite-rest/%s (%s, commit/%s)", info.Main.Version, info.GoVersion, commit) 38 | 39 | return s 40 | } 41 | 42 | func setServerVersion() { 43 | if ServerVersion != "" { 44 | return 45 | } 46 | 47 | if v := loadServerVersionFromBuildInfo(); v != "" { 48 | ServerVersion = v 49 | return 50 | } 51 | 52 | ServerVersion = "(devel)" 53 | } 54 | 55 | func init() { 56 | setServerVersion() 57 | } 58 | --------------------------------------------------------------------------------
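The token-file auth mode above uses the raw bytes of the --auth-token-file (including any trailing newline) as the HMAC secret when verifying bearer tokens. The following is a minimal sketch, not part of the repository, of how a client could mint a compatible HS256 token with the same github.com/golang-jwt/jwt package the server imports; the file name "test.token" and the claims are arbitrary examples.

package main

import (
	"fmt"
	"os"
	"time"

	"github.com/golang-jwt/jwt"
)

func main() {
	// Read the same secret file that is passed to the server via --auth-token-file.
	// The server's key func returns these bytes verbatim, so the exact file
	// contents (newline and all) are the signing key.
	secret, err := os.ReadFile("test.token")
	if err != nil {
		panic(err)
	}

	// Sign an HS256 token; the server accepts HS256/HS384/HS512 in this mode.
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"exp": time.Now().Add(time.Hour).Unix(),
	})
	signed, err := token.SignedString(secret)
	if err != nil {
		panic(err)
	}

	// Send the result as: Authorization: Bearer <signed>
	fmt.Println(signed)
}

A request carrying this token would still need its target table or view allow-listed via --security-allow-table to pass the access-check middleware shown above.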