├── .cursorignore
├── .devcontainer
├── devcontainer.json
└── docker-compose.extend.yml
├── .dockerignore
├── .editorconfig
├── .github
├── dependabot.yml
└── workflows
│ ├── ci.yml
│ └── genesis.yml
├── .gitignore
├── .vscode
├── extensions.json
├── launch.json
└── tasks.json
├── Makefile
├── README.md
├── assets
└── design_doc.md
├── build
└── docker
│ ├── dev
│ ├── Dockerfile
│ ├── docker-compose.yml
│ ├── entrypoint.sh
│ └── mysqlconf
│ │ └── mysql.cnf
│ └── prod
│ ├── Dockerfile
│ ├── Dockerfile.razorpay
│ ├── entrypoint.sh
│ └── probe.sh
├── cmd
├── gateway
│ └── main.go
└── migration
│ └── main.go
├── config
├── default.toml
├── dev-docker.toml
├── prod-dev.toml
└── prod.toml
├── go.mod
├── go.sum
├── internal
├── boot
│ └── boot.go
├── config
│ └── config.go
├── constants
│ └── contextkeys
│ │ └── contextkeys.go
├── frontend
│ ├── README.md
│ ├── components
│ │ ├── adminview.go
│ │ ├── dashboardview.go
│ │ ├── pageview.go
│ │ ├── querylistview.go
│ │ ├── queryview.go
│ │ └── tabview.go
│ ├── core
│ │ └── core.go
│ ├── dispatcher
│ │ └── dispatcher.go
│ ├── main.go
│ └── server
│ │ └── main.go
├── gatewayserver
│ ├── backendApi
│ │ ├── core.go
│ │ ├── server.go
│ │ └── validation.go
│ ├── database
│ │ ├── dbRepo
│ │ │ └── dbRepo.go
│ │ └── migrations
│ │ │ ├── 20210805195304_bootstrap.go
│ │ │ ├── 20211203205304_alter_backends.go
│ │ │ ├── 20220107205304_increase_query_text_size.go
│ │ │ ├── 20240524205304_add_auth_delegation.go
│ │ │ └── 20240525205304_add_set_source.go
│ ├── groupApi
│ │ ├── core.go
│ │ ├── server.go
│ │ └── validation.go
│ ├── healthApi
│ │ ├── core.go
│ │ └── server.go
│ ├── hooks
│ │ ├── auth.go
│ │ ├── ctx.go
│ │ ├── metric.go
│ │ └── requestid.go
│ ├── metrics
│ │ └── metrics.go
│ ├── models
│ │ ├── backend.go
│ │ ├── group.go
│ │ ├── policy.go
│ │ └── query.go
│ ├── policyApi
│ │ ├── core.go
│ │ ├── core_test.go
│ │ ├── server.go
│ │ └── validation.go
│ ├── queryApi
│ │ ├── core.go
│ │ ├── server.go
│ │ └── validation.go
│ └── repo
│ │ ├── backend.go
│ │ ├── group.go
│ │ ├── policy.go
│ │ └── query.go
├── monitor
│ ├── core.go
│ ├── metric.go
│ ├── monitor.go
│ └── trino.go
├── provider
│ └── logger.go
├── router
│ ├── auth.go
│ ├── auth_test.go
│ ├── metric.go
│ ├── request.go
│ ├── request_test.go
│ ├── request_type.go
│ ├── response.go
│ ├── router.go
│ └── trinoheaders
│ │ ├── trino.go
│ │ └── trino_test.go
└── utils
│ ├── utils.go
│ └── utils_test.go
├── pkg
├── config
│ ├── config.go
│ ├── config_test.go
│ └── testdata
│ │ └── default.toml
├── fetcher
│ ├── fetcher.go
│ └── manager.go
├── logger
│ ├── entry.go
│ ├── entry_test.go
│ ├── logger.go
│ └── logger_test.go
└── spine
│ ├── datatype
│ └── validation.go
│ ├── db
│ ├── db.go
│ └── db_test.go
│ ├── errors.go
│ ├── model.go
│ └── repository.go
├── rpc
└── gateway
│ └── service.proto
├── scripts
├── compile.sh
├── coverage.sh
├── docker.sh
├── run-example.sh
└── setup.sh
├── third_party
└── swaggerui
│ ├── favicon-16x16.png
│ ├── favicon-32x32.png
│ ├── index.css
│ ├── index.html
│ ├── oauth2-redirect.html
│ ├── swagger-initializer.js
│ ├── swagger-ui-bundle.js
│ ├── swagger-ui-bundle.js.map
│ ├── swagger-ui-es-bundle-core.js
│ ├── swagger-ui-es-bundle-core.js.map
│ ├── swagger-ui-es-bundle.js
│ ├── swagger-ui-es-bundle.js.map
│ ├── swagger-ui-standalone-preset.js
│ ├── swagger-ui-standalone-preset.js.map
│ ├── swagger-ui.css
│ ├── swagger-ui.css.map
│ ├── swagger-ui.js
│ └── swagger-ui.js.map
├── tools.go
└── web
└── frontend
├── favicon.ico
└── index.html
/.cursorignore:
--------------------------------------------------------------------------------
1 | # Distribution and Environment
2 | dist/*
3 | build/*
4 | venv/*
5 | env/*
6 | *.env
7 | .env.*
8 | virtualenv/*
9 | .python-version
10 | .ruby-version
11 | .node-version
12 |
13 | # Logs and Temporary Files
14 | *.log
15 | *.tsv
16 | *.csv
17 | *.txt
18 | tmp/*
19 | temp/*
20 | .tmp/*
21 | *.temp
22 | *.cache
23 | .cache/*
24 | logs/*
25 |
26 | # Sensitive Data
27 | *.json
28 | *.xml
29 | *.yml
30 | *.yaml
31 | *.properties
32 | properties.json
33 | *.sqlite
34 | *.sqlite3
35 | *.dbsql
36 | secrets.*
37 | *secret*
38 | *password*
39 | *credential*
40 | .npmrc
41 | .yarnrc
42 | .aws/*
43 | .config/*
44 |
45 | # Credentials and Keys
46 | *.pem
47 | *.ppk
48 | *.key
49 | *.pub
50 | *.p12
51 | *.pfx
52 | *.htpasswd
53 | *.keystore
54 | *.jks
55 | *.truststore
56 | *.cer
57 | id_rsa*
58 | known_hosts
59 | authorized_keys
60 | .ssh/*
61 | .gnupg/*
62 | .pgpass
63 |
64 | # Config Files
65 | *.conf
66 | *.toml
67 | *.ini
68 | .env.local
69 | .env.development
70 | .env.test
71 | .env.production
72 | config/*
73 |
74 | # Documentation and Notes
75 | *.md
76 | *.mdx
77 | *.rst
78 | *.txt
79 | docs/*
80 | README*
81 | CHANGELOG*
82 | LICENSE*
83 | CONTRIBUTING*
84 |
85 | # Database Files
86 | *.sql
87 | *.db
88 | *.dmp
89 | *.dump
90 | *.backup
91 | *.restore
92 | *.mdb
93 | *.accdb
94 | *.realm*
95 |
96 | # Backup and Archive Files
97 | *.bak
98 | *.backup
99 | *.swp
100 | *.swo
101 | *.swn
102 | *~
103 | *.old
104 | *.orig
105 | *.archive
106 | *.gz
107 | *.zip
108 | *.tar
109 | *.rar
110 | *.7z
111 |
112 | # Compiled and Binary Files
113 | *.pyc
114 | *.pyo
115 | **/__pycache__/**
116 | *.class
117 | *.jar
118 | *.war
119 | *.ear
120 | *.dll
121 | *.exe
122 | *.so
123 | *.dylib
124 | *.bin
125 | *.obj
126 |
127 | # IDE and Editor Files
128 | .idea/*
129 | *.iml
130 | .vscode/*
131 | .project
132 | .classpath
133 | .settings/*
134 | *.sublime-*
135 | .atom/*
136 | .eclipse/*
137 | *.code-workspace
138 | .history/*
139 |
140 | # Build and Dependency Directories
141 | node_modules/*
142 | bower_components/*
143 | vendor/*
144 | packages/*
145 | jspm_packages/*
146 | .gradle/*
147 | target/*
148 | out/*
149 |
150 | # Testing and Coverage Files
151 | coverage/*
152 | .coverage
153 | htmlcov/*
154 | .pytest_cache/*
155 | .tox/*
156 | junit.xml
157 | test-results/*
158 |
159 | # Mobile Development
160 | *.apk
161 | *.aab
162 | *.ipa
163 | *.xcarchive
164 | *.provisionprofile
165 | google-services.json
166 | GoogleService-Info.plist
167 |
168 | # Certificate and Security Files
169 | *.crt
170 | *.csr
171 | *.ovpn
172 | *.p7b
173 | *.p7s
174 | *.pfx
175 | *.spc
176 | *.stl
177 | *.pem.crt
178 | ssl/*
179 |
180 | # Container and Infrastructure
181 | *.tfstate
182 | *.tfstate.backup
183 | .terraform/*
184 | .vagrant/*
185 | docker-compose.override.yml
186 | kubernetes/*
187 |
188 | # Design and Media Files (often large and binary)
189 | *.psd
190 | *.ai
191 | *.sketch
192 | *.fig
193 | *.xd
194 | assets/raw/*
195 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the
2 | // README at: https://github.com/devcontainers/templates/tree/main/src/docker-outside-of-docker-compose
3 | {
4 | "name": "Existing Docker Compose (Extend)",
5 |
6 | // Update the 'dockerComposeFile' list if you have more compose files or use different names.
7 | // The .devcontainer/docker-compose.extend.yml file contains any overrides you need/want to make.
8 | "dockerComposeFile": [
9 | "../build/docker/dev/docker-compose.yml",
10 | "docker-compose.extend.yml"
11 | ],
12 |
13 | // The 'service' property is the name of the service for the container that VS Code should
14 | // use. It must match the service name in build/docker/dev/docker-compose.yml and .devcontainer/docker-compose.extend.yml.
15 | "service": "trino_gateway",
16 |
17 | // The optional 'workspaceFolder' property is the path VS Code should open by default when
18 | // connected. Here it is the /app bind mount defined in build/docker/dev/docker-compose.yml.
19 | "workspaceFolder": "/app",
20 | // Use 'forwardPorts' to make a list of ports inside the container available locally.
21 | // "forwardPorts": [],
22 |
23 | // Use 'postCreateCommand' to run commands after the container is created.
24 | // "postCreateCommand": "docker --version",
25 |
26 | // Connect as root inside the container (remove this line to use the image's default user). More info: https://aka.ms/dev-containers-non-root.
27 | "remoteUser": "root",
28 | "customizations": {
29 | "vscode": {
30 | "extensions": [
31 | "golang.go",
32 | "zxh404.vscode-proto3",
33 | "be5invis.toml"
34 | ]
35 | }
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/.devcontainer/docker-compose.extend.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 | services:
3 | trino_gateway:
4 | # [Optional] Required for ptrace-based debuggers like C++, Go, and Rust
5 | cap_add:
6 | - SYS_PTRACE
7 | security_opt:
8 | - seccomp:unconfined
9 | # Overrides default command so things don't shut down after the process ends.
10 | entrypoint: /bin/sh -c "while sleep 1000; do :; done"
11 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .git/
2 |
3 | Dockerfile*
4 |
5 | .dockerignore
6 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | ; http://editorconfig.org/
2 |
3 | root = true
4 |
5 | [*]
6 | end_of_line = lf
7 | insert_final_newline = true
8 | charset = utf-8
9 | trim_trailing_whitespace = true
10 |
11 | [*.go]
12 | indent_style = tab
13 | indent_size = 4
14 |
15 | [*.md]
16 | trim_trailing_whitespace = false
17 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: gomod
4 | directory: "/"
5 | schedule:
6 | interval: daily
7 | time: "04:00"
8 | timezone: Asia/Calcutta
9 |
10 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 |
8 | jobs:
9 | build-public-image:
10 | runs-on: ubuntu-latest
11 | steps:
12 | -
13 | name: Set up QEMU
14 | uses: docker/setup-qemu-action@v1
15 | -
16 | name: Set up Docker Buildx
17 | uses: docker/setup-buildx-action@v2
18 | -
19 | name: Login to DockerHub
20 | uses: docker/login-action@v2
21 | with:
22 | username: ${{ secrets.PUBLIC_DOCKER_USERNAME }}
23 | password: ${{ secrets.PUBLIC_DOCKER_PASSWORD }}
24 | -
25 | name: Build and push
26 | id: docker_build
27 | uses: docker/build-push-action@v2
28 | with:
29 | file: build/docker/prod/Dockerfile
30 | push: true
31 | tags: razorpay/presto_gateway:${{ github.sha }}
32 | build-args: GIT_COMMIT_HASH=${{ github.sha }}
33 | -
34 | name: Image digest
35 | run: echo ${{ steps.docker_build.outputs.digest }}
36 |
37 | # Rzp image is built from datahub
38 |
--------------------------------------------------------------------------------
/.github/workflows/genesis.yml:
--------------------------------------------------------------------------------
1 | name: Quality Checks
2 | on:
3 | schedule:
4 | - cron: "0 17 * * *"
5 | jobs:
6 | Analysis:
7 | uses: razorpay/genesis/.github/workflows/quality-checks.yml@master
8 | secrets: inherit
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # JetBrains project folders.
2 | .idea/
3 |
4 | # VSCode test folder
5 | .vscode-test/*
6 |
7 | # Desktop Services Store - Mac.
8 | .DS_Store
9 |
10 | # Vendor modules.
11 | vendor/*
12 |
13 | # App Binaries.
14 | bin/*
15 |
16 | # Documentation. Use `make docs` to generate
17 | docs/*
18 |
19 | # Ignores all mock/* directories in project.
20 | # These are created by mockgen tool and only servers unit tests when they run.
21 | mock/
22 |
23 | # Dont ignore any .gitkeep files, please.
24 | !*.gitkeep
25 |
26 | .tmp
27 |
28 | # This needs to be sourced from the proto3 repo.
29 | proto
30 |
31 | # Generated protobuf files.
32 | rpc/*/*.pb.go
33 | rpc/*/*.twirp.go
34 |
35 | # Generated or local configuration files.
36 | config/dev.*
37 |
38 | # docker-compose .dev file in deployments.
39 | # deployments/docker-compose.dev.yml
40 |
41 |
42 | # Gopherjs compiled files
43 | web/frontend/js/*
44 |
45 | # Generated swagger defs
46 | third_party/swaggerui/rpc/gateway/service.swagger.json
47 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": ["golang.go", "ms-vscode-remote.remote-containers", "be5invis.toml", "zxh404.vscode-proto3", "ms-azuretools.vscode-docker"]
3 | }
4 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "type": "chrome",
9 | "request": "launch",
10 | "name": "Launch gateway GUI in Chrome",
11 | "url": "http://localhost:8000",
12 | "webRoot": "${workspaceFolder}"
13 | }
14 | ]
15 | }
16 |
--------------------------------------------------------------------------------
/.vscode/tasks.json:
--------------------------------------------------------------------------------
1 | {
2 | // See https://go.microsoft.com/fwlink/?LinkId=733558
3 | // for the documentation about the tasks.json format
4 | "version": "2.0.0",
5 | "tasks": [
6 | {
7 | "label": "serve on local",
8 | "type": "shell",
9 | "command": "go run ./cmd/gateway | jq",
10 | "problemMatcher": []
11 | },
12 | {
13 | "label": "serve on local without jq",
14 | "type": "shell",
15 | "command": "go run ./cmd/gateway",
16 | "problemMatcher": []
17 | },
18 | {
19 | "label": "send query submission request to localhost:8080",
20 | "type": "shell",
21 | "command": "curl -X POST http://localhost:8080/v1/statement -H 'X-Trino-User: dev' -d 'SELECT 1' ",
22 | "problemMatcher": []
23 | },
24 | {
25 | "label": "Run local-dev setup example",
26 | "type": "shell",
27 | "command": "source ./scripts/run-example.sh",
28 | "problemMatcher": []
29 | }
30 | ]
31 | }
32 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | SCRIPT_DIR := "./scripts"
2 |
3 | BUILD_OUT_DIR := "bin/"
4 |
5 | PWD = $(shell pwd)
6 |
7 | # Run the project setup script (scripts/setup.sh).
8 | .PHONY: setup
9 | setup:
10 | 	$(SCRIPT_DIR)/setup.sh
11 |
12 | # Run setup and make sure a (possibly empty) local config/dev.toml exists.
13 | .PHONY: dev-setup
14 | dev-setup: setup config/dev.toml
15 |
16 |
17 | config/dev.toml:
18 | 	touch $(PWD)/config/dev.toml
19 |
20 | .PHONY: build
21 | build:
22 | 	$(SCRIPT_DIR)/compile.sh
23 |
24 | .PHONY: build-frontend
25 | build-frontend: web/frontend/js/frontend.js
26 |
27 | # Compile the GopherJS frontend; file target with no prerequisites, so it only runs when the output is missing.
28 | web/frontend/js/frontend.js:
29 | 	echo "Compiling frontend"
30 | 	gopherjs build ./internal/frontend --output "./web/frontend/js/frontend.js" --minify --verbose
31 |
32 | # .PHONY: dev-build
33 | # dev-build:
34 | # $(SCRIPT_DIR)/dev.sh
35 |
36 | .PHONY: dev-docker-up
37 | dev-docker-up:
38 | 	$(SCRIPT_DIR)/docker.sh up
39 |
40 | .PHONY: dev-docker-down
41 | dev-docker-down:
42 | 	$(SCRIPT_DIR)/docker.sh down
43 |
44 | .PHONY: dev-docker-run-example ## Runs bundled example in dev docker env
45 | dev-docker-run-example:
46 | 	$(SCRIPT_DIR)/run-example.sh
47 |
48 | # Build the migration CLI into bin/ and apply all pending migrations.
49 | # `go build` stops parsing flags at the first package argument, so `-o` must come before it.
50 | .PHONY: dev-migration
51 | dev-migration:
52 | 	go build -o bin/migration ./cmd/migration
53 | 	./bin/migration up
54 |
55 | .PHONY: test-integration
56 | test-integration:
57 | 	go test -tags=integration ./test/it -v -count=1
58 |
59 | # Run unit tests in every package of the module (a bare `go test` only covers the root package).
60 | .PHONY: test-unit
61 | test-unit:
62 | 	go test ./...
63 |
--------------------------------------------------------------------------------
/assets/design_doc.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/assets/design_doc.md
--------------------------------------------------------------------------------
/build/docker/dev/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM golang:1.22.3-alpine3.20
2 |
3 | WORKDIR /app
4 |
5 | # Tooling used when building/running the app inside the dev container.
6 | RUN apk update \
7 | && apk add --no-cache bash make protobuf protobuf-dev git gzip curl build-base
8 |
9 | # COPY ./ /app
10 | # RUN go mod download
11 |
12 | RUN go install github.com/githubnemo/CompileDaemon@v1.4.0
13 |
14 | # Only the module manifests are copied so the dependency download is cached in the
15 | # image; the source tree itself is bind-mounted at /app by the dev docker-compose file.
16 | COPY ./go.mod /app/go.mod
17 | COPY ./go.sum /app/go.sum
18 |
19 | # RUN go mod vendor
20 | RUN go mod download
21 |
22 | # Exec form: the script runs as PID 1 instead of under `/bin/sh -c`, so it
23 | # receives signals such as SIGTERM from `docker stop`.
24 | ENTRYPOINT ["/app/build/docker/dev/entrypoint.sh"]
25 |
--------------------------------------------------------------------------------
/build/docker/dev/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 | services:
3 | trino_gateway:
4 | build:
5 | context: ./../../..
6 | dockerfile: build/docker/dev/Dockerfile
7 | image: trino_gateway
8 | # image: utkarshsaxena/utk_trino_gateway:v6
9 | container_name: trino_gateway
10 | volumes:
11 | - ./../../..:/app
12 | environment:
13 | APP_ENV: dev-docker
14 | TRINO-GATEWAY_AUTH_ROUTER_AUTHENTICATE: true
15 | entrypoint: /app/build/docker/dev/entrypoint.sh
16 | # entrypoint: ["tail", "-f", "/dev/null"]
17 | expose:
18 | - "8000"
19 | - "8001"
20 | - "8002"
21 | - "8080"
22 | - "8081"
23 | ports:
24 | - 28000:8000
25 | - 28001:8001
26 | - 28002:8002
27 | - 28080:8080
28 | - 28081:8081
29 | networks:
30 | - default
31 | links:
32 | - trino_gateway_mysql
33 | trino_gateway_mysql:
34 | image: mysql:8.0.28-oracle
35 | container_name: trino_gateway_mysql
36 | volumes:
37 | - ./mysqlconf:/etc/mysql/conf.d
38 | ports:
39 | - 33306:3306
40 | environment:
41 | MYSQL_ROOT_PASSWORD: root123
42 | MYSQL_DATABASE: trino-gateway
43 | networks:
44 | - default
45 | # networks:
46 | # trino:
47 | # external: true
48 |
--------------------------------------------------------------------------------
/build/docker/dev/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Dev container entrypoint: build the app, wait for MySQL, run migrations, then start the gateway.
4 |
5 | # Build the project inside the container (runs `make setup build`).
6 | initialize() {
7 |     # Init app secrets + envvars
8 |     echo "Syncing app deps, if this takes time, update deps in the built image"
9 |     make setup build
10 | }
11 |
12 | # Retry the migration CLI's `version` command until it succeeds, sleeping 3s
13 | # between failures for at most 20 attempts. Exits the script with status 1 if
14 | # the database never becomes reachable.
15 | check_db_connection() {
16 |     # Wait till db is available
17 |     connected=0
18 |     counter=0
19 |
20 |     echo "Wait 60 seconds for connection to MySQL"
21 |     while [[ ${counter} -lt 60 ]]; do
22 |         {
23 |             echo "Connecting to MySQL" && go run ./cmd/migration/main.go version &&
24 |                 connected=1
25 |         } || {
26 |             let counter=$counter+3
27 |             sleep 3
28 |         }
29 |         if [[ ${connected} -eq 1 ]]; then
30 |             echo "Connected"
31 |             break
32 |         fi
33 |     done
34 |
35 |     if [[ ${connected} -eq 0 ]]; then
36 |         echo "MySQL connection failed."
37 |         # A bare `exit` would return the status of the echo above (0) and hide the failure.
38 |         exit 1
39 |     fi
40 | }
41 |
42 | # Apply all pending DB migrations.
43 | db_migrations() {
44 |     go run ./cmd/migration/main.go up
45 | }
46 |
47 | initialize
48 | check_db_connection
49 | # run db migrations
50 | db_migrations
51 |
52 | # run app
53 |
54 | # CompileDaemon -polling-interval=10 -exclude-dir=.git -exclude-dir=vendor --build="gopherjs build ./internal/frontend/main --output "./web/frontend/js/frontend.js" --verbose && go build cmd/gateway/main.go -o gateway" --command=./gateway
55 | go run ./cmd/gateway/main.go
56 | # tail -f /dev/null
57 |
--------------------------------------------------------------------------------
/build/docker/dev/mysqlconf/mysql.cnf:
--------------------------------------------------------------------------------
1 | [mysqld]
2 | skip-host-cache
3 | skip-name-resolve
4 | general_log=1
5 | general_log_file=/var/lib/mysql/general.log
6 |
--------------------------------------------------------------------------------
/build/docker/prod/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM golang:1.22.3-alpine3.20
2 |
3 | # Commit hash supplied at build time (CI passes GIT_COMMIT_HASH=<sha> as a build arg).
4 | ARG GIT_COMMIT_HASH
5 | ENV GIT_COMMIT_HASH=${GIT_COMMIT_HASH}
6 | # NOTE(review): presumably read by the config loader as app.gitCommitHash — confirm the env prefix handling in pkg/config.
7 | ENV TRINO-GATEWAY_APP_GITCOMMITHASH=${GIT_COMMIT_HASH}
8 |
9 | WORKDIR /app
10 |
11 | # Packages for building the app (`make setup build`) and for probe.sh (curl).
12 | RUN apk update \
13 | && apk add --no-cache bash make protobuf protobuf-dev git gzip curl build-base
14 |
15 | COPY ./ /app
16 |
17 | RUN go mod download \
18 | && make setup build
19 | # RUN go mod vendor
20 |
21 | COPY ./build/docker/prod/probe.sh /app/probe.sh
22 |
23 | # Exec form: entrypoint.sh runs as PID 1 rather than under `/bin/sh -c`, so it
24 | # receives signals such as SIGTERM sent by `docker stop`.
25 | ENTRYPOINT ["/app/build/docker/prod/entrypoint.sh"]
26 |
--------------------------------------------------------------------------------
/build/docker/prod/Dockerfile.razorpay:
--------------------------------------------------------------------------------
1 | ## Dockerfile used for Rzp deployment
2 |
3 | FROM c.rzp.io/razorpay/onggi-multi-arch:rzp-golden-image-base-golang-1.22
4 | # TODO: c.rzp.io/razorpay/rzp-docker-image-inventory-multi-arch:rzp-golden-image-base-golang-1.22
5 |
6 | # Commit hash supplied at build time as a build arg.
7 | ARG GIT_COMMIT_HASH
8 | ENV GIT_COMMIT_HASH=${GIT_COMMIT_HASH}
9 | # NOTE(review): presumably read by the config loader as app.gitCommitHash — confirm the env prefix handling in pkg/config.
10 | ENV TRINO-GATEWAY_APP_GITCOMMITHASH=${GIT_COMMIT_HASH}
11 |
12 | WORKDIR /app
13 |
14 | # Packages for building the app (`make setup build`) and for probe.sh (curl).
15 | RUN apk update \
16 | && apk add --no-cache bash make protobuf protobuf-dev git gzip curl build-base
17 |
18 | COPY ./ /app
19 |
20 | RUN go mod download \
21 | && make setup build
22 | # RUN go mod vendor
23 |
24 | COPY ./build/docker/prod/probe.sh /app/probe.sh
25 |
26 | # Exec form: entrypoint.sh runs as PID 1 rather than under `/bin/sh -c`, so it
27 | # receives signals such as SIGTERM sent by `docker stop`.
28 | ENTRYPOINT ["/app/build/docker/prod/entrypoint.sh"]
29 |
--------------------------------------------------------------------------------
/build/docker/prod/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Production entrypoint: wait for MySQL, apply migrations, then start the gateway.
4 |
5 | # Placeholder for app secrets/env initialisation; currently it only logs.
6 | initialize() {
7 |     # Init app secrets + envvars
8 |     echo "Initializing app"
9 | }
10 |
11 | # Retry the migration CLI's `version` command until it succeeds, sleeping 3s
12 | # between failures for at most 20 attempts. Exits the script with status 1 if
13 | # the database never becomes reachable.
14 | check_db_connection() {
15 |     # Wait till db is available
16 |     connected=0
17 |     counter=0
18 |
19 |     echo "Wait 60 seconds for connection to MySQL"
20 |     while [[ ${counter} -lt 60 ]]; do
21 |         {
22 |             echo "Connecting to MySQL" && go run ./cmd/migration/main.go version &&
23 |                 connected=1
24 |         } || {
25 |             let counter=$counter+3
26 |             sleep 3
27 |         }
28 |         if [[ ${connected} -eq 1 ]]; then
29 |             echo "Connected"
30 |             break
31 |         fi
32 |     done
33 |
34 |     if [[ ${connected} -eq 0 ]]; then
35 |         echo "MySQL connection failed."
36 |         # A bare `exit` would return the status of the echo above (0), so the
37 |         # container would look like a clean shutdown instead of a failure.
38 |         exit 1
39 |     fi
40 | }
41 |
42 | # Apply all pending DB migrations.
43 | db_migrations() {
44 |     go run ./cmd/migration/main.go up
45 | }
46 |
47 | initialize
48 | check_db_connection
49 | # run db migrations
50 | db_migrations
51 |
52 | # run app
53 |
54 | # CompileDaemon -polling-interval=10 -exclude-dir=.git -exclude-dir=vendor --build="gopherjs build ./internal/frontend/main --output "./web/frontend/js/frontend.js" --verbose && go build cmd/gateway/main.go -o gateway" --command=./gateway
55 | go run ./cmd/gateway/main.go
56 | # tail -f /dev/null
57 |
--------------------------------------------------------------------------------
/build/docker/prod/probe.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #validating the status code of health check api and returning 1 in case of non 2xx status code
3 | if [ "$(curl -s -o /dev/null -H 'Content-Type: application/json' -d '{"service": ""}' -w '%{http_code}' http://localhost:8000/twirp/razorpay.gateway.HealthCheckAPI/Check)" != 200 ];
4 | then
5 | exit 1;
6 | fi
7 |
--------------------------------------------------------------------------------
/cmd/migration/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"context"
5 | 	"flag"
6 | 	"fmt"
7 | 	"log"
8 | 	"os"
9 |
10 | 	"github.com/pressly/goose/v3"
11 | 	"github.com/razorpay/trino-gateway/internal/boot"
12 | 	_ "github.com/razorpay/trino-gateway/internal/gatewayserver/database/migrations"
13 | )
14 |
15 | // CLI flags. flag.ExitOnError means a malformed flag terminates the process
16 | // inside Parse itself.
17 | var (
18 | 	flags   = flag.NewFlagSet("goose", flag.ExitOnError)
19 | 	dir     = flags.String("dir", "internal/gatewayserver/database/migrations", "Directory with migration files")
20 | 	verbose = flags.Bool("v", false, "Enable verbose mode")
21 | )
22 |
23 | // main wraps goose: `create` and `fix` run against migration files only,
24 | // every other command boots the app's DB connection and runs against it.
25 | func main() {
26 | 	env := boot.GetEnv()
27 |
28 | 	flags.Usage = usage
29 | 	if parseErr := flags.Parse(os.Args[1:]); parseErr != nil {
30 | 		log.Fatalf("error parsing flags: %v", parseErr)
31 | 	}
32 | 	if *verbose {
33 | 		goose.SetVerbose(true)
34 | 	}
35 |
36 | 	positional := flags.Args()
37 | 	// No command given: show help and stop.
38 | 	if len(positional) == 0 {
39 | 		flags.Usage()
40 | 		return
41 | 	}
42 |
43 | 	// The first positional argument is the goose command; the rest are passed through.
44 | 	command, cmdArgs := positional[0], append([]string{}, positional[1:]...)
45 |
46 | 	// File-only commands: no DB connection required.
47 | 	switch command {
48 | 	case "create":
49 | 		exitOnErr(goose.Run(command, nil, *dir, cmdArgs...))
50 | 		return
51 | 	case "fix":
52 | 		exitOnErr(goose.Run(command, nil, *dir))
53 | 		return
54 | 	}
55 |
56 | 	// Everything else needs config and a live DB handle.
57 | 	ctx := context.Background()
58 | 	exitOnErr(boot.InitMigration(ctx, env))
59 | 	exitOnErr(goose.SetDialect(boot.Config.Db.Dialect))
60 |
61 | 	sqlDB, err := boot.DB.Instance(ctx).DB()
62 | 	exitOnErr(err)
63 |
64 | 	exitOnErr(goose.Run(command, sqlDB, *dir, cmdArgs...))
65 | }
66 |
67 | // exitOnErr logs "failed to run command" and exits the process when err is
68 | // non-nil; it does nothing otherwise.
69 | func exitOnErr(err error) {
70 | 	if err != nil {
71 | 		log.Fatalf("failed to run command: %v", err)
72 | 	}
73 | }
74 |
75 | // usage prints the flag defaults followed by the supported goose commands.
76 | func usage() {
77 | 	flags.PrintDefaults()
78 | 	fmt.Println(usageCommands)
79 | }
80 |
81 | var (
82 | 	usageCommands = `
83 | Commands:
84 | up Migrate the DB to the most recent version available
85 | up-to VERSION Migrate the DB to a specific VERSION
86 | down Roll back the version by 1
87 | down-to VERSION Roll back to a specific VERSION
88 | redo Re-run the latest migration
89 | reset Roll back all migrations
90 | status Dump the migration status for the current DB
91 | version Print the current version of the database
92 | create NAME Creates new migration file with the current timestamp
93 | fix Apply sequential ordering to migrations
94 | `
95 | )
96 |
--------------------------------------------------------------------------------
/config/default.toml:
--------------------------------------------------------------------------------
1 | [app]
2 | env = "default"
3 | gitCommitHash = "nil"
4 | logLevel = "info"
5 | metricsPort = 8002
6 | # The GUI and the Twirp app must share the same port for now; see the frontend README for details.
7 | # guiPort = 8000
8 | port = 8000
9 | serviceExternalHostname = "localhost:8080"
10 | serviceHostname = "$$internalHost"
11 | serviceName = "trino-gateway"
12 | shutdownDelay = 2
13 | shutdownTimeout = 5
14 |
15 | [db]
16 | [db.ConnectionConfig]
17 | dialect = "mysql"
18 | protocol = "tcp"
19 | url = "localhost"
20 | port = 33306
21 | username = "root"
22 | password = "root123"
23 | sslMode = "require"
24 | name = "trino-gateway"
25 | [db.ConnectionPoolConfig]
26 | maxOpenConnections = 5
27 | maxIdleConnections = 5
28 | connectionMaxLifetime = 0
29 |
30 | [auth]
31 | token = "test123"
32 | tokenHeaderKey = "X-Auth-Key"
33 | [auth.router.delegatedAuth]
34 | validationProviderURL = "localhost:28001"
35 | validationProviderToken = "test123"
36 | cacheTTLMinutes = "10m"
37 |
38 |
39 |
40 |
41 | [gateway]
42 | ports = [8080, 8081]
43 | defaultRoutingGroup = "adhoc"
44 | # Empty means 0.0.0.0 (all interfaces), which is only needed when running inside a Docker container; set to `localhost` otherwise.
45 | network = ""
46 |
47 | [monitor]
48 | interval = "10m"
49 | statsValiditySecs = 0
50 | healthCheckSql = "SELECT 1"
51 | [monitor.trino]
52 | user = "trino-gateway"
53 | password = ""
54 |
--------------------------------------------------------------------------------
/config/dev-docker.toml:
--------------------------------------------------------------------------------
1 | [app]
2 | gitCommitHash = "dev-docker"
3 | logLevel = "debug"
4 |
5 | [db]
6 | [db.ConnectionConfig]
7 | url = "trino_gateway_mysql"
8 | port = 3306
9 |
10 | [gateway]
11 | defaultRoutingGroup = "dev"
12 |
--------------------------------------------------------------------------------
/config/prod-dev.toml:
--------------------------------------------------------------------------------
1 | [app]
2 | logLevel = "info"
3 | [monitor]
4 | interval = "20s"
5 | healthCheckSql = "SHOW SCHEMAS FROM hive"
6 | [gateway]
7 | ports = [8080, 8081, 8082, 8083]
8 |
--------------------------------------------------------------------------------
/config/prod.toml:
--------------------------------------------------------------------------------
1 | [app]
2 | logLevel = "warn"
3 | [monitor]
4 | interval = "20s"
5 | healthCheckSql = "SHOW SCHEMAS FROM hive"
6 | [gateway]
7 | ports = [8080, 8081, 8082, 8083, 8084, 8085, 8086, 8087, 8088]
8 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/razorpay/trino-gateway
2 |
3 | go 1.22
4 |
5 | require (
6 | github.com/DATA-DOG/go-sqlmock v1.5.2
7 | github.com/NYTimes/gziphandler v1.1.1
8 | github.com/dlmiddlecote/sqlstats v1.0.2
9 | github.com/fatih/structs v1.1.0
10 | github.com/go-co-op/gocron v1.35.0 // pinned: releases after v1.35.0 are broken for v1
11 | github.com/go-ozzo/ozzo-validation v3.6.0+incompatible
12 | github.com/go-ozzo/ozzo-validation/v4 v4.3.0
13 | github.com/go-sql-driver/mysql v1.8.1
14 | github.com/gobuffalo/nulls v0.4.2
15 | github.com/golang/protobuf v1.5.4
16 | github.com/gopherjs/gopherjs v1.19.0-beta1
17 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0
18 | github.com/hexops/vecty v0.6.0
19 | github.com/lib/pq v1.10.9
20 | github.com/pressly/goose/v3 v3.20.0
21 | github.com/prometheus/client_golang v1.19.1
22 | github.com/robfig/cron/v3 v3.0.1
23 | github.com/rs/xid v1.5.0
24 | github.com/spf13/viper v1.18.2
25 | github.com/stretchr/testify v1.9.0
26 | github.com/trinodb/trino-go-client v0.315.0
27 | github.com/twitchtv/twirp v8.1.3+incompatible
28 | go.uber.org/zap v1.27.0
29 | google.golang.org/protobuf v1.34.1
30 | gorm.io/driver/mysql v1.1.2
31 | gorm.io/driver/postgres v1.1.2
32 | gorm.io/gorm v1.21.16
33 | gorm.io/plugin/dbresolver v1.1.0
34 | )
35 |
36 | require (
37 | filippo.io/edwards25519 v1.1.0 // indirect
38 | github.com/beorn7/perks v1.0.1 // indirect
39 | github.com/cespare/xxhash/v2 v2.3.0 // indirect
40 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
41 | github.com/evanw/esbuild v0.21.4 // indirect
42 | github.com/fsnotify/fsnotify v1.7.0 // indirect
43 | github.com/gofrs/uuid v4.4.0+incompatible // indirect
44 | github.com/google/uuid v1.6.0 // indirect
45 | github.com/hashicorp/go-uuid v1.0.3 // indirect
46 | github.com/hashicorp/hcl v1.0.0 // indirect
47 | github.com/inconshreveable/mousetrap v1.1.0 // indirect
48 | github.com/jackc/chunkreader/v2 v2.0.1 // indirect
49 | github.com/jackc/pgconn v1.14.3 // indirect
50 | github.com/jackc/pgio v1.0.0 // indirect
51 | github.com/jackc/pgpassfile v1.0.0 // indirect
52 | github.com/jackc/pgproto3/v2 v2.3.3 // indirect
53 | github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
54 | github.com/jackc/pgtype v1.14.3 // indirect
55 | github.com/jackc/pgx/v4 v4.18.3 // indirect
56 | github.com/jcmturner/gofork v1.7.6 // indirect
57 | github.com/jinzhu/inflection v1.0.0 // indirect
58 | github.com/jinzhu/now v1.1.5 // indirect
59 | github.com/magiconair/properties v1.8.7 // indirect
60 | github.com/mfridman/interpolate v0.0.2 // indirect
61 | github.com/mitchellh/mapstructure v1.5.0 // indirect
62 | github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86 // indirect
63 | github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c // indirect
64 | github.com/pelletier/go-toml/v2 v2.2.2 // indirect
65 | github.com/pkg/errors v0.9.1 // indirect
66 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
67 | github.com/prometheus/client_model v0.6.1 // indirect
68 | github.com/prometheus/common v0.53.0 // indirect
69 | github.com/prometheus/procfs v0.15.0 // indirect
70 | github.com/sagikazarmark/locafero v0.4.0 // indirect
71 | github.com/sagikazarmark/slog-shim v0.1.0 // indirect
72 | github.com/sethvargo/go-retry v0.2.4 // indirect
73 | github.com/sirupsen/logrus v1.9.3 // indirect
74 | github.com/sourcegraph/conc v0.3.0 // indirect
75 | github.com/spf13/afero v1.11.0 // indirect
76 | github.com/spf13/cast v1.6.0 // indirect
77 | github.com/spf13/cobra v1.8.0 // indirect
78 | github.com/spf13/pflag v1.0.5 // indirect
79 | github.com/subosito/gotenv v1.6.0 // indirect
80 | github.com/visualfc/goembed v0.3.3 // indirect
81 | go.uber.org/atomic v1.11.0 // indirect
82 | go.uber.org/multierr v1.11.0 // indirect
83 | golang.org/x/crypto v0.23.0 // indirect
84 | golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d // indirect
85 | golang.org/x/sync v0.7.0 // indirect
86 | golang.org/x/sys v0.20.0 // indirect
87 | golang.org/x/term v0.20.0 // indirect
88 | golang.org/x/text v0.15.0 // indirect
89 | golang.org/x/tools v0.21.0 // indirect
90 | google.golang.org/genproto/googleapis/api v0.0.0-20240521202816-d264139d666e // indirect
91 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240521202816-d264139d666e // indirect
92 | google.golang.org/grpc v1.64.0 // indirect
93 | gopkg.in/ini.v1 v1.67.0 // indirect
94 | gopkg.in/jcmturner/aescts.v1 v1.0.1 // indirect
95 | gopkg.in/jcmturner/dnsutils.v1 v1.0.1 // indirect
96 | gopkg.in/jcmturner/gokrb5.v6 v6.1.1 // indirect
97 | gopkg.in/jcmturner/rpc.v1 v1.1.0 // indirect
98 | gopkg.in/yaml.v3 v3.0.1 // indirect
99 | )
100 |
--------------------------------------------------------------------------------
/internal/boot/boot.go:
--------------------------------------------------------------------------------
1 | package boot
2 |
3 | import (
4 | "context"
5 | "log"
6 | "os"
7 | "strings"
8 |
9 | "github.com/dlmiddlecote/sqlstats"
10 | "github.com/fatih/structs"
11 | "github.com/prometheus/client_golang/prometheus"
12 | "github.com/razorpay/trino-gateway/internal/config"
13 | "github.com/razorpay/trino-gateway/internal/constants/contextkeys"
14 | config_reader "github.com/razorpay/trino-gateway/pkg/config"
15 | "github.com/razorpay/trino-gateway/pkg/logger"
16 | "github.com/razorpay/trino-gateway/pkg/spine/db"
17 | "github.com/rs/xid"
18 | )
19 |
const (
	// requestIDHttpHeaderKey is the HTTP header name for a request ID.
	// NOTE(review): not referenced anywhere else in this file — confirm it is still used.
	requestIDHttpHeaderKey = "X-Request-ID"
	// requestIDCtxKey is a string name for the request-ID context key.
	// NOTE(review): GetRequestID/WithRequestID use contextkeys.RequestID instead,
	// so this constant looks unused here — confirm before relying on it.
	requestIDCtxKey = "RequestID"
)

var (
	// Config contains application configuration values.
	// It is loaded in init() for the environment returned by GetEnv().
	Config config.Config

	// DB holds the application db connection.
	// It is opened from Config.Db in init(); init() exits the process if that fails.
	DB *db.DB
)
32 |
33 | func init() {
34 | // Init config
35 | err := config_reader.NewDefaultConfig().Load(GetEnv(), &Config)
36 | if err != nil {
37 | log.Fatal(err)
38 | }
39 |
40 | InitLogger(context.Background())
41 |
42 | // Init Db
43 | DB, err = db.NewDb(&Config.Db)
44 | if err != nil {
45 | log.Fatal(err.Error())
46 | }
47 | }
48 |
49 | // Fetch env for bootstrapping
50 | func GetEnv() string {
51 | environment := os.Getenv("APP_ENV")
52 | if environment == "" {
53 | log.Print("APP_ENV not set defaulting to dev env.", environment)
54 | environment = "dev"
55 | }
56 |
57 | log.Print("Setting app env to ", environment)
58 |
59 | return environment
60 | }
61 |
62 | // GetRequestID gets the request id
63 | // if its already set in the given context
64 | // if there is no requestID set then it'll create a new
65 | // request id and returns the same
66 | func GetRequestID(ctx context.Context) string {
67 | if val, ok := ctx.Value(contextkeys.RequestID).(string); ok {
68 | return val
69 | }
70 | return xid.New().String()
71 | }
72 |
73 | // WithRequestID adds a request if to the context and gives the updated context back
74 | // if the passed requestID is empty then creates one by itself
75 | func WithRequestID(ctx context.Context, requestID string) context.Context {
76 | if requestID == "" {
77 | requestID = xid.New().String()
78 | }
79 |
80 | return context.WithValue(ctx, contextkeys.RequestID, requestID)
81 | }
82 |
83 | // initialize all core dependencies for the application
84 | func initialize(ctx context.Context, env string) error {
85 | log := InitLogger(ctx)
86 |
87 | ctx = context.WithValue(ctx, logger.LoggerCtxKey, log)
88 |
89 | // Puts git commit hash into config.
90 | // This is not read automatically because env variable is not in expected format.
91 | if v, found := os.LookupEnv("GIT_COMMIT_HASH"); found {
92 | Config.App.GitCommitHash = v
93 | }
94 |
95 | // Register DB stats prometheus collector
96 | dbInstance, err := DB.Instance(ctx).DB()
97 | if err != nil {
98 | return err
99 | }
100 | collector := sqlstats.NewStatsCollector(Config.Db.URL+"-"+Config.Db.Name, dbInstance)
101 | prometheus.MustRegister(collector)
102 |
103 | return nil
104 | }
105 |
106 | func InitApi(ctx context.Context, env string) error {
107 | err := initialize(ctx, env)
108 | if err != nil {
109 | return err
110 | }
111 |
112 | return nil
113 | }
114 |
115 | func InitMigration(ctx context.Context, env string) error {
116 | err := initialize(ctx, env)
117 | if err != nil {
118 | return err
119 | }
120 |
121 | return nil
122 | }
123 |
124 | // // InitTracing initialises opentracing exporter
125 | // func InitTracing(ctx context.Context) (io.Closer, error) {
126 | // t, closer, err := tracing.Init(Config.Tracing, Logger(ctx))
127 |
128 | // Tracer = t
129 |
130 | // return closer, err
131 | // }
132 |
133 | // NewContext adds core key-value e.g. service name, git hash etc to
134 | // existing context or to a new background context and returns.
135 | func NewContext(ctx context.Context) context.Context {
136 | if ctx == nil {
137 | ctx = context.Background()
138 | }
139 | for k, v := range structs.Map(struct {
140 | GitCommitHash string
141 | Env string
142 | ServiceName string
143 | }{
144 | GitCommitHash: Config.App.GitCommitHash,
145 | Env: Config.App.Env,
146 | ServiceName: Config.App.ServiceName,
147 | }) {
148 | key := strings.ToLower(k)
149 | ctx = context.WithValue(ctx, key, v)
150 | }
151 | return ctx
152 | }
153 |
154 | func InitLogger(ctx context.Context) *logger.ZapLogger {
155 | lgrConfig := logger.Config{
156 | LogLevel: Config.App.LogLevel,
157 | ContextString: "trino-gateway",
158 | }
159 |
160 | Logger, err := logger.NewLogger(lgrConfig)
161 | if err != nil {
162 | panic("failed to initialize logger")
163 | }
164 |
165 | return Logger
166 | }
167 |
--------------------------------------------------------------------------------
/internal/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "github.com/razorpay/trino-gateway/pkg/spine/db"
5 | )
6 |
// Config is the root application configuration. boot.init loads it via the
// config reader for the environment chosen by boot.GetEnv.
type Config struct {
	App     App
	Auth    Auth
	Db      db.Config
	Gateway Gateway
	Monitor Monitor
}
14 |
// App contains application-specific config values.
type App struct {
	// Env is the environment name; boot.NewContext copies it into contexts as "env".
	Env string
	// GitCommitHash is overwritten at boot from GIT_COMMIT_HASH when that variable is set.
	GitCommitHash string
	// LogLevel is passed to the logger built by boot.InitLogger.
	LogLevel    string
	MetricsPort int
	Port        int
	// NOTE(review): the two hostnames' consumers are not visible here — confirm their roles.
	ServiceExternalHostname string
	ServiceHostname         string
	// ServiceName is copied into contexts by boot.NewContext as "servicename".
	ServiceName string
	// NOTE(review): units are not visible here (presumably seconds) — confirm.
	ShutdownDelay   int
	ShutdownTimeout int
}
28 |
// Auth groups authentication settings.
// NOTE(review): consumers are not in this file; the notes below are inferred
// from field names — confirm against the auth hooks.
type Auth struct {
	Token          string
	TokenHeaderKey string
	Router         struct {
		DelegatedAuth struct {
			ValidationProviderURL   string
			ValidationProviderToken string
			// CacheTTLMinutes is a string, not a number.
			// NOTE(review): presumably parsed as a minute count by its consumer — confirm.
			CacheTTLMinutes string
		}
	}
}
40 |
// Gateway holds gateway listener/routing settings.
// NOTE(review): semantics are inferred from field names — confirm against the gateway server.
type Gateway struct {
	DefaultRoutingGroup string
	// Ports is a list, so the gateway presumably listens on several ports.
	Ports   []int
	Network string
}
46 |
// Monitor holds backend-monitoring settings.
// NOTE(review): the monitor implementation is not in this file; the notes
// below are inferred from field names — confirm.
type Monitor struct {
	// Interval is a string schedule; presumably a cron spec (robfig/cron is a dependency) — confirm.
	Interval string
	// StatsValiditySecs is presumably in seconds, per its name.
	StatsValiditySecs int
	// Trino holds the credentials used when talking to Trino backends.
	Trino struct {
		User     string
		Password string
	}
	HealthCheckSql string
}
56 |
--------------------------------------------------------------------------------
/internal/constants/contextkeys/contextkeys.go:
--------------------------------------------------------------------------------
1 | package contextkeys
2 |
// contextkeys is an unexported key type for context values, so keys declared
// here cannot collide with keys defined by other packages.
type contextkeys int

const (
	// RequestID is the context key under which boot.WithRequestID stores the
	// request ID string.
	RequestID contextkeys = iota
)
6 |
--------------------------------------------------------------------------------
/internal/frontend/README.md:
--------------------------------------------------------------------------------
### Notes
GopherJS compiles Go to JavaScript that runs entirely in the client's browser.
Unlike the JVM, where Nashorn or a similar engine can run JS code in-process,
we don't have that luxury here. As a result, the gatewayserver API and the UI
must be served from the same origin; otherwise the browser's CORS policy
(https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS/Errors/CORSMissingAllowOrigin)
blocks requests between them.
This also avoids having to maintain UI state on the server, since the UI is entirely client side.
8 | This also solves the problem of maintaining state for UI as its completely client side
9 |
10 | ### Running in dev mode for testing
11 |
The easiest way to run the frontend as WebAssembly is via [`wasmserve`](https://github.com/hajimehoshi/wasmserve).
13 |
Install it (**using Go 1.16+**, which supports `go install pkg@version`):

```bash
go install github.com/hajimehoshi/wasmserve@latest
```
19 |
Then run it from the frontend directory:
21 |
22 | ```bash
23 | cd trino-gateway/internal/frontend
24 | wasmserve
25 | ```
26 |
27 | Then navigate to http://localhost:8080/
--------------------------------------------------------------------------------
/internal/frontend/components/adminview.go:
--------------------------------------------------------------------------------
1 | package components
2 |
--------------------------------------------------------------------------------
/internal/frontend/components/dashboardview.go:
--------------------------------------------------------------------------------
1 | package components
2 |
--------------------------------------------------------------------------------
/internal/frontend/components/pageview.go:
--------------------------------------------------------------------------------
1 | package components
2 |
3 | import (
4 | "github.com/hexops/vecty"
5 | "github.com/hexops/vecty/elem"
6 | "github.com/razorpay/trino-gateway/internal/frontend/core"
7 | )
8 |
// PageView is a vecty.Component which represents the entire page.
type PageView struct {
	vecty.Core
	// core is passed to the query list view on every Render.
	core core.ICore
}
14 |
15 | func GetNewPageViewComponent(c core.ICore) *PageView {
16 | return &PageView{core: c}
17 | }
18 |
19 | func (p *PageView) Render() vecty.ComponentOrHTML {
20 | return elem.Body(
21 | elem.Section(
22 | vecty.Markup(
23 | vecty.Class("section"),
24 | ),
25 | elem.Div(
26 | vecty.Markup(
27 | vecty.Class("container"),
28 | ),
29 | p.renderHeader(),
30 | NewQueryListView(p.core),
31 | p.renderFooter(),
32 | ),
33 | ),
34 | )
35 | }
36 |
37 | func (p *PageView) renderHeader() *vecty.HTML {
38 | return elem.Div(
39 | vecty.Markup(
40 | vecty.Class("tabs", "is-centered", "is-large", "is-fullwidth", "is-toggle", "is-toggle-rounded"),
41 | ),
42 | elem.UnorderedList(
43 | elem.ListItem(
44 | vecty.Markup(
45 | vecty.Class("is-active"),
46 | ),
47 | &TabView{title: "Query History", hrefUrl: "#"},
48 | ),
49 | elem.ListItem(&TabView{title: "Dashboard", hrefUrl: "#"}),
50 | elem.ListItem(&TabView{title: "Admin", hrefUrl: "/admin/swaggerui"}),
51 | ),
52 | )
53 | }
54 |
55 | func (p *PageView) renderFooter() *vecty.HTML {
56 | return elem.Footer(
57 | vecty.Markup(
58 | vecty.Class("footer"),
59 | ),
60 | elem.Div(
61 | vecty.Markup(
62 | vecty.Class("content", "has-text-centered"),
63 | ),
64 | vecty.Text("trino gateway footer"),
65 | ),
66 | )
67 | }
68 |
--------------------------------------------------------------------------------
/internal/frontend/components/querylistview.go:
--------------------------------------------------------------------------------
1 | package components
2 |
3 | import (
4 | "fmt"
5 | "math"
6 |
7 | "github.com/hexops/vecty"
8 | "github.com/hexops/vecty/elem"
9 | "github.com/hexops/vecty/event"
10 | "github.com/hexops/vecty/prop"
11 | "github.com/razorpay/trino-gateway/internal/frontend/core"
12 | )
13 |
// queryListView is a vecty.Component which represents the query history section.
// It keeps every fetched query as a child component in items and shows one
// page of them at a time, according to params.
type queryListView struct {
	vecty.Core
	core   core.ICore
	items  vecty.List
	params queryListViewParams
}
21 |
// queryListViewParams holds the filter and paging state of queryListView.
type queryListViewParams struct {
	// Username is forwarded to the query API as the Username filter; nothing in this file sets it yet.
	Username string
	// MaxItems caps how many queries are fetched from the API.
	MaxItems int
	// TotalItems is the number of queries actually fetched (at most MaxItems).
	TotalItems int
	// PageIndex is the zero-based index of the page being shown.
	PageIndex int
	// ItemsPerPage is the page size used for client-side paging.
	ItemsPerPage int
}
29 |
30 | func NewQueryListView(core core.ICore) *queryListView {
31 | p := &queryListView{
32 | core: core,
33 | params: queryListViewParams{
34 | PageIndex: 0,
35 | ItemsPerPage: 100,
36 | MaxItems: 1000,
37 | },
38 | }
39 |
40 | if err := p.populateItems(); err != nil {
41 | fmt.Printf("%s: %s\n", "Unable to fetch list of queries.", err.Error())
42 | }
43 | return p
44 | }
45 |
46 | func (p *queryListView) populateItems() error {
47 | // Get total eligible items
48 | queries, err := p.core.GetQueries(p.params.MaxItems, 0, p.params.Username)
49 | if err != nil {
50 | return err
51 | }
52 | p.params.TotalItems = len(queries)
53 | // queries, err = p.core.GetQueries(p.params.ItemsPerPage, p.params.PageIndex*p.params.ItemsPerPage, p.params.Username)
54 | // if err != nil {
55 | // return err
56 | // }
57 | for _, q := range queries {
58 | query := &QueryView{
59 | Query: q,
60 | }
61 | p.items = append(p.items, query)
62 | }
63 | return nil
64 | }
65 |
66 | func (p *queryListView) Render() vecty.ComponentOrHTML {
67 | return elem.Div(
68 | vecty.Markup(
69 | vecty.Class("container", "tile", "is-vertical", "is-ancestor"),
70 | ),
71 | p.renderHeader(),
72 | p.renderItems(),
73 | p.renderPagination(),
74 | )
75 | }
76 |
77 | func (p *queryListView) renderHeader() vecty.ComponentOrHTML {
78 | return elem.Div(
79 | vecty.Markup(
80 | vecty.Class("tile", "is-parent"),
81 | ),
82 | elem.Div(
83 | vecty.Markup(
84 | vecty.Class("tile", "is-child"),
85 | ),
86 | vecty.Text(fmt.Sprintf("Total: %d", p.params.TotalItems)),
87 | ),
88 | elem.Div(
89 | vecty.Markup(
90 | vecty.Class("tile", "is-child"),
91 | ),
92 | elem.Input(
93 | vecty.Markup(
94 | vecty.Class("input"),
95 | prop.Type(prop.TypeText),
96 | prop.Placeholder("Search for Username"), // initial textarea text.
97 | ),
98 |
99 | // When input is typed into the textarea, update the local
100 | // component state and rerender.
101 | // event.Input(func(e *vecty.Event) {
102 | // p.Input = e.Target.Get("value").String()
103 | // vecty.Rerender(p)
104 | // }),
105 | ),
106 | ),
107 | elem.Div(
108 | vecty.Markup(
109 | vecty.Class("tile", "is-child"),
110 | ),
111 | elem.Div(
112 | vecty.Markup(
113 | vecty.Class("select", "is-rounded"),
114 | ),
115 | elem.Select(
116 | elem.Option(vecty.Text("100"+" Entries per page")),
117 | elem.Option(vecty.Text("200"+" Entries per page")),
118 | elem.Option(vecty.Text("500"+" Entries per page")),
119 | ),
120 | ),
121 |
122 | //
123 | //
137 | //
138 | ),
139 | )
140 | }
141 |
142 | func (p *queryListView) onEditSearchbox(e *vecty.Event) {
143 | }
144 |
145 | func (p *queryListView) onClickPageNavigation(_ *vecty.Event) {
146 | p.params.PageIndex = p.params.PageIndex + 1
147 | vecty.Rerender(p)
148 | }
149 |
150 | func (p *queryListView) renderItems() vecty.ComponentOrHTML {
151 | r := vecty.List{}
152 | for i, v := range p.items {
153 | if i >= p.params.PageIndex*p.params.ItemsPerPage && i < (p.params.PageIndex+1)*p.params.ItemsPerPage {
154 | r = append(r, v)
155 | }
156 | }
157 | return elem.OrderedList(r)
158 | }
159 |
160 | func (p *queryListView) renderPagination() vecty.ComponentOrHTML {
161 | totPag := int(math.Ceil(float64(p.params.TotalItems) / float64(p.params.ItemsPerPage)))
162 | currPag := p.params.PageIndex + 1
163 |
164 | return elem.Div(
165 | vecty.Markup(
166 | vecty.Class("tile", "is-parent"),
167 | ),
168 | elem.Div(
169 | vecty.Markup(
170 | vecty.Class("tile", "is-child"),
171 | ),
172 | elem.Navigation(
173 | vecty.Markup(
174 | vecty.Class("pagination", "is-rounded"),
175 | vecty.Property("role", "navigation"),
176 | vecty.Property("aria-label", "pagination"),
177 | ),
178 | elem.Anchor(
179 | vecty.Markup(
180 | vecty.MarkupIf(currPag == 1, vecty.Style("display", "none")),
181 | vecty.Class("pagination-previous"),
182 | ),
183 | vecty.Text("Previous"),
184 | ),
185 | elem.Anchor(
186 | vecty.Markup(
187 | vecty.MarkupIf(currPag == totPag, vecty.Style("display", "none")),
188 | vecty.Class("pagination-next"),
189 | event.Click(p.onClickPageNavigation).PreventDefault(),
190 | ),
191 | vecty.Text("Next page"),
192 | ),
193 | elem.UnorderedList(
194 | vecty.Markup(
195 | vecty.Class("pagination-list"),
196 | ),
197 | elem.ListItem(elem.Anchor(
198 | vecty.Markup(
199 | vecty.MarkupIf(currPag == 1, vecty.Style("display", "none")),
200 | vecty.Class("pagination-link"),
201 | vecty.Property("aria-label", "Goto page 1"),
202 | ),
203 | vecty.Text("1"),
204 | )),
205 | elem.ListItem(elem.Span(
206 | vecty.Markup(
207 | vecty.MarkupIf(currPag == 1, vecty.Style("display", "none")),
208 | vecty.Class("pagination-ellipsis"),
209 | ),
210 | vecty.Text("..."),
211 | )),
212 | elem.ListItem(elem.Anchor(
213 | vecty.Markup(
214 | vecty.Class("pagination-link", "is-current"),
215 | vecty.Property("aria-label", fmt.Sprint("Goto page ", currPag)),
216 | ),
217 | vecty.Text(fmt.Sprint(currPag)),
218 | )),
219 | elem.ListItem(elem.Span(
220 | vecty.Markup(
221 | vecty.MarkupIf(currPag == totPag, vecty.Style("display", "none")),
222 | vecty.Class("pagination-ellipsis"),
223 | ),
224 | vecty.Text("..."),
225 | )),
226 | elem.ListItem(elem.Anchor(
227 | vecty.Markup(
228 | vecty.MarkupIf(currPag == totPag, vecty.Style("display", "none")),
229 | vecty.Class("pagination-link"),
230 | vecty.Property("aria-label", fmt.Sprint("Goto page ", totPag)),
231 | ),
232 | vecty.Text(fmt.Sprint(totPag)),
233 | )),
234 | ),
235 | ),
236 | ),
237 | )
238 | }
239 |
--------------------------------------------------------------------------------
/internal/frontend/components/queryview.go:
--------------------------------------------------------------------------------
1 | package components
2 |
3 | import (
4 | "fmt"
5 | "time"
6 |
7 | "github.com/hexops/vecty"
8 | "github.com/hexops/vecty/elem"
9 | "github.com/hexops/vecty/event"
10 | "github.com/hexops/vecty/prop"
11 | "github.com/hexops/vecty/style"
12 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
13 | )
14 |
// QueryView is a vecty.Component which represents a single item in the queryHistory List.
type QueryView struct {
	vecty.Core
	// core core.ICore

	// Query is the query rendered by this item.
	Query *gatewayv1.Query
	// classes is the class map applied to the metadata tile; the pointer
	// enter/leave handlers toggle its "is-info" and "is-link" entries.
	classes vecty.ClassMap
}
23 |
24 | /*
25 | QueryId - href to trinoUi
26 | User
27 | backendId
28 | GroupId
29 | submittedAt
30 | Text
31 | */
32 |
33 | func (p *QueryView) Render() vecty.ComponentOrHTML {
34 | return elem.Div(
35 | vecty.Markup(
36 | vecty.Class("box", "tile", "is-parent", "notification", "is-light"),
37 | vecty.Style("display", "flex"),
38 | vecty.Style("flex-direction", "row"),
39 | ),
40 | p.renderMeta(),
41 | p.renderText(),
42 | )
43 | }
44 |
// QueryMetaItem is a vecty.Component for one metadata entry of a query.
type QueryMetaItem struct {
	vecty.Core
	// k is the entry's label (e.g. "Username"); Render does not currently show it.
	k string
	// v is the displayed value.
	v string
}
50 |
// classes is the default class map for a QueryView's metadata tile.
// NOTE(review): this is a single package-level map. Maps are reference types,
// so anything that stores this value and mutates it in place changes it for
// every other holder. Copy it before mutating it per view.
var classes = vecty.ClassMap{
	"is-info":  true,
	"is-light": true,
	"is-link":  false,
}
57 |
58 | func (q *QueryMetaItem) Render() vecty.ComponentOrHTML {
59 | return vecty.Text(q.v)
60 | }
61 |
62 | func (p *QueryView) renderMeta() *vecty.HTML {
63 | // https://trino-gateway.de.razorpay.com/ui/query.html?20211003_083931_06657_n3mb3
64 | url := fmt.Sprintf("%s/ui/query.html?%s", p.Query.ServerHost, p.Query.Id)
65 |
66 | // fails at runtime in js
67 | // loc, _ := time.LoadLocation("Asia/Kolkata")
68 |
69 | _items := [...]*QueryMetaItem{
70 | {
71 | k: "SubmittedAt",
72 | v: time.Unix(p.Query.GetSubmittedAt(), 0).Local().Format("2006/01/02 15:04:05"), // Golang time format layout is weird stuff.
73 | },
74 | {k: "Username", v: p.Query.GetUsername()},
75 | {k: "BackendId", v: p.Query.GetBackendId()},
76 | {k: "GroupId", v: p.Query.GetGroupId()},
77 | }
78 |
79 | var items vecty.List
80 | for _, i := range _items {
81 | item := elem.ListItem(i)
82 | items = append(items, item)
83 | }
84 |
85 | p.classes = classes
86 |
87 | return elem.Div(
88 | vecty.Markup(
89 | p.classes,
90 | vecty.Class("tile", "is-child", "notification"),
91 | style.Width("30%"),
92 | vecty.Style("display", "flex"),
93 | vecty.Style("flex-direction", "column"),
94 | event.PointerEnter(p.onPointerEnter),
95 | event.PointerLeave(p.onPointerLeave),
96 | ),
97 | elem.Bold(
98 | vecty.Markup(
99 | vecty.Class("subtitle"),
100 | ),
101 | elem.Anchor(
102 | vecty.Markup(
103 | prop.Href(url),
104 | ),
105 | vecty.Text(p.Query.GetId()),
106 | ),
107 | ),
108 | elem.UnorderedList(items),
109 | )
110 | }
111 |
112 | func (p *QueryView) renderText() *vecty.HTML {
113 | return elem.Div(
114 | vecty.Markup(
115 | vecty.Class("tile", "is-child", "is-8"),
116 | style.Width("70%"),
117 | style.Height("6.5em"),
118 | vecty.Style("word-wrap", "break-word"),
119 | style.Overflow(style.OverflowHidden),
120 | vecty.Style("text-overflow", "ellipsis"),
121 | vecty.Style("resize", "vertical"),
122 | vecty.Style("text-align", "center"),
123 | ),
124 | vecty.Text(p.Query.GetText()),
125 | )
126 | }
127 |
128 | func (p *QueryView) onPointerEnter(e *vecty.Event) {
129 | p.classes["is-info"] = false
130 | p.classes["is-link"] = true
131 | vecty.Rerender(p)
132 | }
133 |
134 | func (p *QueryView) onPointerLeave(e *vecty.Event) {
135 | p.classes["is-info"] = true
136 | p.classes["is-link"] = false
137 | vecty.Rerender(p)
138 | }
139 |
--------------------------------------------------------------------------------
/internal/frontend/components/tabview.go:
--------------------------------------------------------------------------------
1 | package components
2 |
3 | import (
4 | "github.com/hexops/vecty"
5 | "github.com/hexops/vecty/elem"
6 | "github.com/hexops/vecty/prop"
7 | )
8 |
// TabView is a vecty.Component which represents a single element in the tab bar.
type TabView struct {
	vecty.Core
	// title is the text shown in the tab.
	title string
	// isSelected adds the "is-active" class to the anchor when true.
	// NOTE(review): nothing in this file sets it — confirm whether it is still needed.
	isSelected bool
	// TODO: remove this
	// hrefUrl is the link target of the tab's anchor.
	hrefUrl string
	// component is not used in this file.
	// NOTE(review): a pointer to an interface is rarely intended — confirm.
	component *vecty.ComponentOrHTML
}
18 |
19 | func (p *TabView) Render() vecty.ComponentOrHTML {
20 | return elem.Anchor(
21 | vecty.Markup(
22 | vecty.MarkupIf(p.isSelected, vecty.Class("is-active")),
23 | prop.Href(p.hrefUrl),
24 | // event.Click(p.onClick).PreventDefault(),
25 | ),
26 | vecty.Text(p.title),
27 | )
28 | }
29 |
30 | func (p *TabView) onClick(e *vecty.Event) {
31 | }
32 |
--------------------------------------------------------------------------------
/internal/frontend/core/core.go:
--------------------------------------------------------------------------------
1 | package core
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net/http"
8 |
9 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
10 | )
11 |
// GatewayApiClient bundles the generated RPC clients for each gateway API
// service. NewCore builds all four against the same gateway host.
type GatewayApiClient struct {
	Policy  gatewayv1.PolicyApi
	Backend gatewayv1.BackendApi
	Group   gatewayv1.GroupApi
	Query   gatewayv1.QueryApi
}
18 |
// Core is the frontend's data-access layer; *Core implements ICore on top of
// the gateway API clients.
type Core struct {
	gatewayApiClient *GatewayApiClient
}
22 |
// ICore is the data-access interface the frontend components depend on.
type ICore interface {
	// GetQueries lists up to count queries after skipping skip, passing user
	// as the request's Username filter (an empty user presumably means "all
	// users" — confirm server-side).
	GetQueries(count int, skip int, user string) ([]*gatewayv1.Query, error)
}
26 |
27 | func NewCore(gatewayHost string) *Core {
28 | return &Core{
29 | gatewayApiClient: &GatewayApiClient{
30 | Backend: gatewayv1.NewBackendApiProtobufClient(gatewayHost, &http.Client{}),
31 | Group: gatewayv1.NewGroupApiProtobufClient(gatewayHost, &http.Client{}),
32 | Policy: gatewayv1.NewPolicyApiProtobufClient(gatewayHost, &http.Client{}),
33 | Query: gatewayv1.NewQueryApiProtobufClient(gatewayHost, &http.Client{}),
34 | },
35 | }
36 | }
37 |
38 | func (c *Core) GetQueries(count int, skip int, user string) ([]*gatewayv1.Query, error) {
39 | req := gatewayv1.QueriesListRequest{
40 | Skip: int32(skip),
41 | Count: int32(count),
42 | Username: user,
43 | }
44 | queriesResp, err := c.gatewayApiClient.Query.ListQueries(context.Background(), &req)
45 | if err != nil {
46 | println(err.Error())
47 | return nil, errors.New(fmt.Sprint("Unable to Fetch list of queries", err.Error()))
48 | }
49 |
50 | return queriesResp.GetItems(), nil
51 | }
52 |
--------------------------------------------------------------------------------
/internal/frontend/dispatcher/dispatcher.go:
--------------------------------------------------------------------------------
1 | package dispatcher
2 |
// ID is a unique identifier representing a registered callback function.
type ID int

var (
	idCounter ID
	callbacks = make(map[ID]func(action interface{}))
	// order lists live callback IDs in registration order. Go randomises map
	// iteration order, so Dispatch walks this slice to stay deterministic.
	order []ID
)

// Dispatch dispatches the given action to all registered callbacks, in the
// order they were registered. The ID list is snapshotted first. Callbacks
// registered during a dispatch are not called for that action. Callbacks
// unregistered during a dispatch are skipped if not yet reached.
func Dispatch(action interface{}) {
	snapshot := append([]ID(nil), order...)
	for _, id := range snapshot {
		if c, ok := callbacks[id]; ok {
			c(action)
		}
	}
}

// Register registers the callback to handle dispatched actions, the returned
// ID may be used to unregister the callback later.
func Register(callback func(action interface{})) ID {
	idCounter++
	id := idCounter
	callbacks[id] = callback
	order = append(order, id)
	return id
}

// Unregister unregisters the callback previously registered via a call to
// Register. Unknown or already-unregistered IDs are ignored.
func Unregister(id ID) {
	if _, ok := callbacks[id]; !ok {
		return
	}
	delete(callbacks, id)
	for i, registered := range order {
		if registered == id {
			order = append(order[:i], order[i+1:]...)
			break
		}
	}
}
32 |
--------------------------------------------------------------------------------
/internal/frontend/main.go:
--------------------------------------------------------------------------------
// GopherJS doesn't build anything other than package `main`.
2 | package main
3 |
4 | import (
5 | "github.com/gopherjs/gopherjs/js"
6 | "github.com/hexops/vecty"
7 | "github.com/razorpay/trino-gateway/internal/frontend/components"
8 | "github.com/razorpay/trino-gateway/internal/frontend/core"
9 | )
10 |
11 | func main() {
12 | path := accessURL() // fmt.Sprint("http://localhost:", "28000")
13 | c := core.NewCore(path)
14 |
15 | vecty.SetTitle("Trino-Gateway")
16 | // vecty.AddStylesheet("https://rawgit.com/tastejs/todomvc-common/master/base.css")
17 | // vecty.AddStylesheet("https://rawgit.com/tastejs/todomvc-app-css/master/index.css")
18 | vecty.AddStylesheet("https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css")
19 |
20 | vecty.RenderBody(components.GetNewPageViewComponent(c))
21 | }
22 |
23 | var location = js.Global.Get("location")
24 |
25 | func accessURL() string {
26 | // current URL: http://localhost:8000/code/gopherjs/window-location/index.html?a=1
27 |
28 | // return - http://localhost:8000/code/gopherjs/window-location/index.html?a=1
29 | location.Get("href").String()
30 | // return - localhost:8000
31 | location.Get("host").String()
32 | // return - localhost
33 | location.Get("hostname").String()
34 | // return - /code/gopherjs/window-location/index.html
35 | location.Get("pathname").String()
36 | // return - http:
37 | location.Get("protocol").String()
38 | // return - http://localhost:8000
39 | location.Get("origin").String()
40 | // return - 8000
41 | location.Get("port").String()
42 | // return - ?a=1
43 | location.Get("search").String()
44 |
45 | return location.Get("origin").String()
46 | }
47 |
--------------------------------------------------------------------------------
/internal/frontend/server/main.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "net/http"
6 |
7 | "github.com/NYTimes/gziphandler"
8 | )
9 |
10 | func NewServerHandler(ctx *context.Context) *http.Handler {
11 | guiFs := http.FileServer(http.Dir("./web/frontend"))
12 | appFrontendPath := "/"
13 | h := cacheHandler(
14 | compressionHandler(
15 | http.StripPrefix(appFrontendPath, guiFs),
16 | ),
17 | )
18 |
19 | return &h
20 | }
21 |
22 | func cacheHandler(h http.Handler) http.Handler {
23 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24 | w.Header().Set("Cache-Control", "max-age=180")
25 | h.ServeHTTP(w, r)
26 | })
27 | }
28 |
29 | func compressionHandler(h http.Handler) http.Handler {
30 | // TODO: check https://github.com/CAFxX/httpcompression
31 | return gziphandler.GzipHandler(h)
32 | }
33 |
--------------------------------------------------------------------------------
/internal/gatewayserver/backendApi/core.go:
--------------------------------------------------------------------------------
1 | package backendapi
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/fatih/structs"
7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models"
8 | "github.com/razorpay/trino-gateway/internal/gatewayserver/repo"
9 | )
10 |
// Core holds the backend business logic; *Core implements ICore on top of a
// backend repository.
type Core struct {
	backendRepo repo.IBackendRepo
}
14 |
// ICore is the backend management API used by the backend RPC server.
type ICore interface {
	// CreateOrUpdateBackend updates the backend with params.ID if it can be found, else creates it.
	CreateOrUpdateBackend(ctx context.Context, params *BackendCreateParams) error
	// GetBackend returns the backend with the given id.
	GetBackend(ctx context.Context, id string) (*models.Backend, error)
	// GetAllBackends returns every backend.
	GetAllBackends(ctx context.Context) ([]models.Backend, error)
	// GetAllActiveBackends returns backends matching the is_enabled=true condition.
	GetAllActiveBackends(ctx context.Context) ([]models.Backend, error)
	// UpdateBackend updates an existing backend; the lookup error is returned if it is missing.
	UpdateBackend(ctx context.Context, b *models.Backend) error
	// The remaining methods delegate directly to the repository by id.
	DeleteBackend(ctx context.Context, id string) error
	EnableBackend(ctx context.Context, id string) error
	DisableBackend(ctx context.Context, id string) error
	MarkHealthyBackend(ctx context.Context, id string) error
	MarkUnhealthyBackend(ctx context.Context, id string) error
}
27 |
28 | func NewCore(backend repo.IBackendRepo) *Core {
29 | return &Core{backendRepo: backend}
30 | }
31 |
// BackendCreateParams has the attributes CreateOrUpdateBackend uses to build a
// models.Backend. ID selects the backend to update or create.
type BackendCreateParams struct {
	ID             string
	Hostname       string
	Scheme         string
	ExternalUrl    string
	IsEnabled      bool
	IsHealthy      bool
	UptimeSchedule string
	// NOTE(review): load units/scale are not visible here — confirm.
	ClusterLoad          int32
	ThresholdClusterLoad int32
	// StatsUpdatedAt is an int64 timestamp; presumably Unix seconds — confirm.
	StatsUpdatedAt int64
}
45 |
46 | func (c *Core) CreateOrUpdateBackend(ctx context.Context, params *BackendCreateParams) error {
47 | backend := models.Backend{
48 | Hostname: params.Hostname,
49 | Scheme: params.Scheme,
50 | ExternalUrl: ¶ms.ExternalUrl,
51 | IsEnabled: ¶ms.IsEnabled,
52 | IsHealthy: ¶ms.IsHealthy,
53 | UptimeSchedule: ¶ms.UptimeSchedule,
54 | ClusterLoad: ¶ms.ClusterLoad,
55 | ThresholdClusterLoad: ¶ms.ThresholdClusterLoad,
56 | StatsUpdatedAt: ¶ms.StatsUpdatedAt,
57 | }
58 | backend.ID = params.ID
59 |
60 | _, exists := c.backendRepo.Find(ctx, params.ID)
61 | if exists == nil { // update
62 | return c.backendRepo.Update(ctx, &backend)
63 | } else { // create
64 | return c.backendRepo.Create(ctx, &backend)
65 | }
66 | }
67 |
68 | func (c *Core) GetBackend(ctx context.Context, id string) (*models.Backend, error) {
69 | backend, err := c.backendRepo.Find(ctx, id)
70 | return backend, err
71 | }
72 |
73 | func (c *Core) UpdateBackend(ctx context.Context, b *models.Backend) error {
74 | _, exists := c.backendRepo.Find(ctx, b.ID)
75 | if exists != nil {
76 | return exists
77 | }
78 | return c.backendRepo.Update(ctx, b)
79 | }
80 |
81 | func (c *Core) GetAllBackends(ctx context.Context) ([]models.Backend, error) {
82 | backends, err := c.backendRepo.FindMany(ctx, make(map[string]interface{}))
83 | return backends, err
84 | }
85 |
// IFindManyParams is the filter accepted by Core.FindMany. FindMany reads the
// concrete value's fields via reflection, so the implementation should be a
// struct (or pointer to one) with `json` tags.
type IFindManyParams interface {
	// Pagination accessors, not implemented yet:
	// GetCount() int32
	// GetSkip() int32
	// GetFrom() int32
	// GetTo() int32

	// custom
	GetIsEnabled() bool
}
95 |
// FindManyParams is the default IFindManyParams. Its `json` tag names become
// the condition keys passed to the repository by Core.FindMany.
type FindManyParams struct {
	// pagination (not implemented yet)
	// Count int32
	// Skip int32
	// From int32
	// To int32

	// custom
	// IsEnabled has no omitempty, so the is_enabled condition is always present (false included).
	IsEnabled bool `json:"is_enabled"`
}
106 |
107 | func (p *FindManyParams) GetIsEnabled() bool {
108 | return p.IsEnabled
109 | }
110 |
111 | func (c *Core) FindMany(ctx context.Context, params IFindManyParams) ([]models.Backend, error) {
112 | conditionStr := structs.New(params)
113 | // use the json tag name, so we can respect omitempty tags
114 | conditionStr.TagName = "json"
115 | conditions := conditionStr.Map()
116 |
117 | return c.backendRepo.FindMany(ctx, conditions)
118 | }
119 |
120 | func (c *Core) GetAllActiveBackends(ctx context.Context) ([]models.Backend, error) {
121 | backends, err := c.FindMany(ctx, &FindManyParams{IsEnabled: true})
122 | return backends, err
123 | }
124 |
125 | func (c *Core) DeleteBackend(ctx context.Context, id string) error {
126 | return c.backendRepo.Delete(ctx, id)
127 | }
128 |
129 | func (c *Core) EnableBackend(ctx context.Context, id string) error {
130 | return c.backendRepo.Enable(ctx, id)
131 | }
132 |
133 | func (c *Core) DisableBackend(ctx context.Context, id string) error {
134 | return c.backendRepo.Disable(ctx, id)
135 | }
136 |
137 | func (c *Core) MarkHealthyBackend(ctx context.Context, id string) error {
138 | return c.backendRepo.MarkHealthy(ctx, id)
139 | }
140 |
141 | func (c *Core) MarkUnhealthyBackend(ctx context.Context, id string) error {
142 | return c.backendRepo.MarkUnhealthy(ctx, id)
143 | }
144 |
// EvaluateClientParams carries client-side attributes for evaluation.
// NOTE(review): ListeningPort presumably corresponds to the 'listening_port'
// policy rule type — confirm against callers.
type EvaluateClientParams struct {
	ListeningPort int32
}
148 |
--------------------------------------------------------------------------------
/internal/gatewayserver/backendApi/server.go:
--------------------------------------------------------------------------------
1 | package backendapi
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "time"
8 |
9 | _ "github.com/twitchtv/twirp"
10 |
11 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models"
12 | "github.com/razorpay/trino-gateway/internal/provider"
13 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
14 | )
15 |
16 | // Server has methods implementing of server rpc.
17 | type Server struct {
18 | core ICore
19 | }
20 |
21 | // NewServer returns a server.
22 | func NewServer(core ICore) *Server {
23 | return &Server{
24 | core: core,
25 | }
26 | }
27 |
28 | // Create creates a new backend
29 | func (s *Server) CreateOrUpdateBackend(ctx context.Context, req *gatewayv1.Backend) (*gatewayv1.Empty, error) {
30 | // defer span.Finish()
31 |
32 | provider.Logger(ctx).Debugw("CreateOrUpdateBackend", map[string]interface{}{
33 | "request": req.String(),
34 | })
35 |
36 | createParams := BackendCreateParams{
37 | ID: req.GetId(),
38 | Scheme: req.GetScheme().Enum().String(),
39 | Hostname: req.GetHostname(),
40 | ExternalUrl: req.GetExternalUrl(),
41 | IsEnabled: req.GetIsEnabled(),
42 | IsHealthy: req.GetIsHealthy(),
43 | UptimeSchedule: req.GetUptimeSchedule(),
44 | ClusterLoad: req.GetClusterLoad(),
45 | ThresholdClusterLoad: req.GetThresholdClusterLoad(),
46 | StatsUpdatedAt: req.GetStatsUpdatedAt(),
47 | }
48 |
49 | err := s.core.CreateOrUpdateBackend(ctx, &createParams)
50 | if err != nil {
51 | return nil, err
52 | }
53 |
54 | return &gatewayv1.Empty{}, nil
55 | }
56 |
57 | // Get retrieves a single backend record
58 | func (s *Server) GetBackend(ctx context.Context, req *gatewayv1.BackendGetRequest) (*gatewayv1.BackendGetResponse, error) {
59 | provider.Logger(ctx).Debugw("GetBackend", map[string]interface{}{
60 | "request": req.String(),
61 | })
62 |
63 | backend, err := s.core.GetBackend(ctx, req.GetId())
64 | if err != nil {
65 | return nil, err
66 | }
67 | backendProto, err := toBackendResponseProto(backend)
68 | if err != nil {
69 | return nil, err
70 | }
71 | return &gatewayv1.BackendGetResponse{Backend: backendProto}, nil
72 | }
73 |
74 | // List fetches a list of filtered backend records
75 | func (s *Server) ListAllBackends(ctx context.Context, req *gatewayv1.Empty) (*gatewayv1.BackendListAllResponse, error) {
76 | provider.Logger(ctx).Debugw("ListAllBackends", map[string]interface{}{
77 | "request": req.String(),
78 | })
79 | backends, err := s.core.GetAllBackends(ctx)
80 | if err != nil {
81 | return nil, err
82 | }
83 |
84 | backendsProto := make([]*gatewayv1.Backend, len(backends))
85 | for i, backendModel := range backends {
86 | backend, err := toBackendResponseProto(&backendModel)
87 | if err != nil {
88 | return nil, err
89 | }
90 | backendsProto[i] = backend
91 | }
92 |
93 | response := gatewayv1.BackendListAllResponse{
94 | Items: backendsProto,
95 | }
96 |
97 | return &response, nil
98 | }
99 |
100 | // Approve marks a backends status to approved
101 |
102 | func (s *Server) EnableBackend(ctx context.Context, req *gatewayv1.BackendEnableRequest) (*gatewayv1.Empty, error) {
103 | provider.Logger(ctx).Debugw("EnableBackend", map[string]interface{}{
104 | "request": req.String(),
105 | })
106 | err := s.core.EnableBackend(ctx, req.GetId())
107 | if err != nil {
108 | return nil, err
109 | }
110 |
111 | return &gatewayv1.Empty{}, nil
112 | }
113 |
114 | func (s *Server) DisableBackend(ctx context.Context, req *gatewayv1.BackendDisableRequest) (*gatewayv1.Empty, error) {
115 | provider.Logger(ctx).Debugw("DisableBackend", map[string]interface{}{
116 | "request": req.String(),
117 | })
118 | err := s.core.DisableBackend(ctx, req.GetId())
119 | if err != nil {
120 | return nil, err
121 | }
122 |
123 | return &gatewayv1.Empty{}, nil
124 | }
125 |
126 | func (s *Server) MarkHealthyBackend(ctx context.Context, req *gatewayv1.BackendMarkHealthyRequest) (*gatewayv1.Empty, error) {
127 | provider.Logger(ctx).Debugw("MarkHealthyBackend", map[string]interface{}{
128 | "request": req.String(),
129 | })
130 | err := s.core.MarkHealthyBackend(ctx, req.GetId())
131 | if err != nil {
132 | return nil, err
133 | }
134 |
135 | return &gatewayv1.Empty{}, nil
136 | }
137 |
138 | func (s *Server) MarkUnhealthyBackend(ctx context.Context, req *gatewayv1.BackendMarkUnhealthyRequest) (*gatewayv1.Empty, error) {
139 | provider.Logger(ctx).Debugw("MarkUnhealthyBackend", map[string]interface{}{
140 | "request": req.String(),
141 | })
142 | err := s.core.MarkUnhealthyBackend(ctx, req.GetId())
143 | if err != nil {
144 | return nil, err
145 | }
146 |
147 | return &gatewayv1.Empty{}, nil
148 | }
149 |
150 | func (s *Server) UpdateClusterLoadBackend(
151 | ctx context.Context,
152 | req *gatewayv1.BackendUpdateClusterLoadRequest,
153 | ) (*gatewayv1.Empty, error) {
154 | provider.Logger(ctx).Debugw("UpdateClusterLoadBackend", map[string]interface{}{
155 | "request": req.String(),
156 | })
157 | b, err := s.core.GetBackend(ctx, req.GetId())
158 | if err != nil {
159 | return nil, err
160 | }
161 | *b.ClusterLoad = req.GetClusterLoad()
162 | *b.StatsUpdatedAt = time.Now().Unix()
163 |
164 | if err := s.core.UpdateBackend(ctx, b); err != nil {
165 | return nil, err
166 | }
167 |
168 | return &gatewayv1.Empty{}, nil
169 | }
170 |
171 | // Delete deletes a backend, soft-delete
172 | func (s *Server) DeleteBackend(ctx context.Context, req *gatewayv1.BackendDeleteRequest) (*gatewayv1.Empty, error) {
173 | provider.Logger(ctx).Debugw("DeleteBackend", map[string]interface{}{
174 | "request": req.String(),
175 | })
176 | err := s.core.DeleteBackend(ctx, req.GetId())
177 | if err != nil {
178 | return nil, err
179 | }
180 |
181 | return &gatewayv1.Empty{}, nil
182 | }
183 |
184 | func toBackendResponseProto(backend *models.Backend) (*gatewayv1.Backend, error) {
185 | if backend == nil {
186 | return &gatewayv1.Backend{}, nil
187 | }
188 | scheme, ok := gatewayv1.Backend_Scheme_value[backend.Scheme]
189 | if !ok {
190 | return nil, errors.New(fmt.Sprint("error encoding response: invalid scheme ", backend.Scheme))
191 | }
192 | response := gatewayv1.Backend{
193 | Id: backend.ID,
194 | Hostname: backend.Hostname,
195 | Scheme: *gatewayv1.Backend_Scheme(scheme).Enum(),
196 | ExternalUrl: *backend.ExternalUrl,
197 | IsEnabled: *backend.IsEnabled,
198 | UptimeSchedule: *backend.UptimeSchedule,
199 | ClusterLoad: *backend.ClusterLoad,
200 | ThresholdClusterLoad: *backend.ThresholdClusterLoad,
201 | StatsUpdatedAt: *backend.StatsUpdatedAt,
202 | IsHealthy: *backend.IsHealthy,
203 | }
204 |
205 | return &response, nil
206 | }
207 |
--------------------------------------------------------------------------------
/internal/gatewayserver/backendApi/validation.go:
--------------------------------------------------------------------------------
1 | package backendapi
2 |
3 | // import (
4 | // validation "github.com/go-ozzo/ozzo-validation/v4"
5 | // )
6 |
7 | // func (cp *CreateParams) Validate() error {
8 | // err := validation.ValidateStruct(cp,
9 | // // id, required, length non zero
10 | // validation.Field(&cp.ID, validation.Required, validation.RuneLength(1, 50)),
11 |
12 | // // Hostname, required, string, length 1-50
13 | // validation.Field(&cp.Hostname, validation.Required, validation.RuneLength(1, 50)),
14 |
15 | // // Scheme, required, string, Union(http, https)
16 | // validation.Field(&cp.Scheme, validation.Required, validation.In("http", "https")),
17 |
18 | // // // last_name, required, string, length 1-30
19 | // // validation.Field(&cp.LastName, validation.Required, validation.RuneLength(1, 30)),
20 | // )
21 |
22 | // return err
23 | // // if err == nil {
24 | // // return nil
25 | // // }
26 |
27 | // // publicErr := errorclass.ErrValidationFailure.New("").
28 | // // Wrap(err).
29 | // // WithPublic(&errors.Public{
30 | // // Description: err.Error(),
31 | // // })
32 |
33 | // // return publicErr
34 | // }
35 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/dbRepo/dbRepo.go:
--------------------------------------------------------------------------------
1 | package dbRepo
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/razorpay/trino-gateway/pkg/spine"
7 | "github.com/razorpay/trino-gateway/pkg/spine/db"
8 | )
9 |
10 | type DbRepo struct {
11 | spine.Repo
12 | }
13 |
14 | type IDbRepo interface {
15 | Create(ctx context.Context, receiver spine.IModel) error
16 | FindByID(ctx context.Context, receiver spine.IModel, id string) error
17 | FindWithConditionByIDs(ctx context.Context, receivers interface{}, condition map[string]interface{}, ids []string) error
18 | FindMany(ctx context.Context, receivers interface{}, condition map[string]interface{}) error
19 | Delete(ctx context.Context, receiver spine.IModel) error
20 | Update(ctx context.Context, receiver spine.IModel, attrList ...string) error
21 | Preload(ctx context.Context, query string, args ...interface{}) *spine.Repo
22 | ClearAssociations(ctx context.Context, receiver spine.IModel, name string) error
23 | ReplaceAssociations(ctx context.Context, receiver spine.IModel, name string, ass interface{}) error
24 | }
25 |
26 | func NewDbRepo(db *db.DB) IDbRepo {
27 | return &DbRepo{
28 | Repo: spine.Repo{
29 | Db: db,
30 | },
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20210805195304_bootstrap.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | "database/sql"
5 |
6 | "github.com/pressly/goose/v3"
7 | )
8 |
9 | func init() {
10 | goose.AddMigration(Up20210805195304, Down20210805195304)
11 | }
12 |
13 | func Up20210805195304(tx *sql.Tx) error {
14 | var err error
15 | _, err = tx.Exec(`
16 | CREATE TABLE backends (
17 | id VARCHAR(255) NOT NULL,
18 | hostname VARCHAR(255) NOT NULL,
19 | scheme ENUM('http', 'https') DEFAULT 'http',
20 | external_url VARCHAR(255) NOT NULL,
21 | is_enabled bool DEFAULT FALSE,
22 | uptime_schedule VARCHAR(255) DEFAULT '* * * * *',
23 | cluster_load INT DEFAULT 0,
24 | threshold_cluster_load INT DEFAULT 0,
25 | stats_updated_at INT(11) NULL,
26 | created_at INT(11) NOT NULL,
27 | updated_at INT(11) NOT NULL,
28 | PRIMARY KEY (id),
29 | KEY users_created_at_index (created_at),
30 | KEY users_updated_at_index (updated_at)
31 | );`)
32 | if err != nil {
33 | return err
34 | }
35 |
36 | // groups is a keyword in mysql, so we use groups_
37 | _, err = tx.Exec("CREATE TABLE `groups_` (" +
38 | `id VARCHAR(255) NOT NULL,
39 | strategy ENUM('random', 'round_robin', 'least_load') DEFAULT 'random',
40 | is_enabled bool DEFAULT FALSE,
41 | last_routed_backend VARCHAR(255),
42 | created_at INT(11) NOT NULL,
43 | updated_at INT(11) NOT NULL,
44 | PRIMARY KEY (id),
45 | KEY users_created_at_index (created_at),
46 | KEY users_updated_at_index (updated_at)
47 | );`)
48 | if err != nil {
49 | return err
50 | }
51 |
52 | _, err = tx.Exec(`CREATE TABLE group_backends_mappings (
53 | id int AUTO_INCREMENT,
54 | group_id varchar(255),
55 | backend_id varchar(255),
56 | created_at int(11),
57 | updated_at int(11),
58 | PRIMARY KEY (id),
59 | UNIQUE KEY (group_id, backend_id),
60 | KEY users_created_at_index (created_at),
61 | KEY users_updated_at_index (updated_at)
62 | );`)
63 | if err != nil {
64 | return err
65 | }
66 |
67 | _, err = tx.Exec(`CREATE TABLE policies (
68 | id varchar(255),
69 | rule_type ENUM ('header_client_tags', 'header_connection_properties', 'header_client_host', 'listening_port'),
70 | rule_value varchar(255),
71 | group_id varchar(255),
72 | fallback_group_id varchar(255),
73 | is_enabled bool,
74 | created_at int(11),
75 | updated_at int(11),
76 | PRIMARY KEY (id),
77 | KEY users_created_at_index (created_at),
78 | KEY users_updated_at_index (updated_at)
79 | );`)
80 | if err != nil {
81 | return err
82 | }
83 |
84 | _, err = tx.Exec(`CREATE TABLE queries (
85 | id varchar(255),
86 | text varchar(255),
87 | client_ip varchar(255),
88 | group_id varchar(255) NULL,
89 | backend_id varchar(255) NULL,
90 | username varchar(255),
91 | server_host varchar(255),
92 | submitted_at int(11),
93 | created_at int(11),
94 | updated_at int(11),
95 | PRIMARY KEY (id),
96 | KEY users_created_at_index (created_at),
97 | KEY users_updated_at_index (updated_at)
98 | );`)
99 | if err != nil {
100 | return err
101 | }
102 |
103 | _, err = tx.Exec(`ALTER TABLE group_backends_mappings ADD FOREIGN KEY (group_id) REFERENCES groups_ (id);`)
104 | if err != nil {
105 | return err
106 | }
107 |
108 | _, err = tx.Exec(`ALTER TABLE group_backends_mappings ADD FOREIGN KEY (backend_id) REFERENCES backends (id) ON DELETE CASCADE ON UPDATE CASCADE;`)
109 | if err != nil {
110 | return err
111 | }
112 |
113 | _, err = tx.Exec(`ALTER TABLE policies ADD FOREIGN KEY (group_id) REFERENCES groups_ (id);`)
114 | if err != nil {
115 | return err
116 | }
117 |
118 | _, err = tx.Exec(`ALTER TABLE policies ADD FOREIGN KEY (fallback_group_id) REFERENCES groups_ (id);`)
119 | if err != nil {
120 | return err
121 | }
122 |
123 | _, err = tx.Exec(`ALTER TABLE queries ADD FOREIGN KEY (group_id) REFERENCES groups_ (id);`)
124 | if err != nil {
125 | return err
126 | }
127 |
128 | _, err = tx.Exec(`ALTER TABLE queries ADD FOREIGN KEY (backend_id) REFERENCES backends (id);`)
129 | if err != nil {
130 | return err
131 | }
132 | return err
133 | }
134 |
135 | func Down20210805195304(tx *sql.Tx) error {
136 | var err error
137 | _, err = tx.Exec(`DROP TABLE IF EXISTS queries;`)
138 | if err != nil {
139 | return err
140 | }
141 | _, err = tx.Exec(`DROP TABLE IF EXISTS policies;`)
142 | if err != nil {
143 | return err
144 | }
145 | _, err = tx.Exec(`DROP TABLE IF EXISTS group_backends_mappings;`)
146 | if err != nil {
147 | return err
148 | }
149 | _, err = tx.Exec(`DROP TABLE IF EXISTS backends;`)
150 | if err != nil {
151 | return err
152 | }
153 | _, err = tx.Exec(`DROP TABLE IF EXISTS groups_;`)
154 | if err != nil {
155 | return err
156 | }
157 | return err
158 | }
159 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20211203205304_alter_backends.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | "database/sql"
5 |
6 | "github.com/pressly/goose/v3"
7 | )
8 |
9 | func init() {
10 | goose.AddMigration(Up20211203205304, Down20211203205304)
11 | }
12 |
13 | func Up20211203205304(tx *sql.Tx) error {
14 | var err error
15 |
16 | _, err = tx.Exec("ALTER TABLE `backends` ADD COLUMN `is_healthy` BOOL DEFAULT false;")
17 | if err != nil {
18 | return err
19 | }
20 | return err
21 | }
22 |
23 | func Down20211203205304(tx *sql.Tx) error {
24 | var err error
25 |
26 | _, err = tx.Exec("ALTER TABLE `backends` DROP COLUMN `is_healthy`;")
27 | if err != nil {
28 | return err
29 | }
30 | return err
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20220107205304_increase_query_text_size.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | "database/sql"
5 |
6 | "github.com/pressly/goose/v3"
7 | )
8 |
9 | func init() {
10 | goose.AddMigration(Up20220107205304, Down20220107205304)
11 | }
12 |
13 | func Up20220107205304(tx *sql.Tx) error {
14 | var err error
15 |
16 | _, err = tx.Exec("ALTER TABLE `queries` MODIFY COLUMN `text` VARCHAR(500);")
17 | if err != nil {
18 | return err
19 | }
20 | return err
21 | }
22 |
23 | func Down20220107205304(tx *sql.Tx) error {
24 | var err error
25 |
26 | _, err = tx.Exec("ALTER TABLE `queries` MODIFY COLUMN `text` VARCHAR(255);")
27 | if err != nil {
28 | return err
29 | }
30 | return err
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20240524205304_add_auth_delegation.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | "database/sql"
5 |
6 | "github.com/pressly/goose/v3"
7 | )
8 |
9 | func init() {
10 | goose.AddMigration(Up20240524205304, Down20240524205304)
11 | }
12 |
13 | func Up20240524205304(tx *sql.Tx) error {
14 | var err error
15 |
16 | _, err = tx.Exec("ALTER TABLE `policies` ADD COLUMN `is_auth_delegated` BOOL DEFAULT false;")
17 | if err != nil {
18 | return err
19 | }
20 | return err
21 | }
22 |
23 | func Down20240524205304(tx *sql.Tx) error {
24 | var err error
25 |
26 | _, err = tx.Exec("ALTER TABLE `policies` DROP COLUMN `is_auth_delegated`;")
27 | if err != nil {
28 | return err
29 | }
30 | return err
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20240525205304_add_set_source.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | "database/sql"
5 |
6 | "github.com/pressly/goose/v3"
7 | )
8 |
9 | func init() {
10 | goose.AddMigration(Up20240525205304, Down20240525205304)
11 | }
12 |
13 | func Up20240525205304(tx *sql.Tx) error {
14 | var err error
15 |
16 | _, err = tx.Exec("ALTER TABLE `policies` ADD COLUMN `set_request_source` VARCHAR(255) DEFAULT '';")
17 | if err != nil {
18 | return err
19 | }
20 | return err
21 | }
22 |
23 | func Down20240525205304(tx *sql.Tx) error {
24 | var err error
25 |
26 | _, err = tx.Exec("ALTER TABLE `policies` DROP COLUMN `set_request_source`;")
27 | if err != nil {
28 | return err
29 | }
30 | return err
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/groupApi/server.go:
--------------------------------------------------------------------------------
1 | package groupapi
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "strings"
8 |
9 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models"
10 | "github.com/razorpay/trino-gateway/internal/provider"
11 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
12 | _ "github.com/twitchtv/twirp"
13 | )
14 |
15 | // Server has methods implementing of server rpc.
16 | type Server struct {
17 | core ICore
18 | }
19 |
20 | // NewServer returns a server.
21 | func NewServer(core ICore) *Server {
22 | return &Server{
23 | core: core,
24 | }
25 | }
26 |
27 | // Create creates a new group
28 | func (s *Server) CreateOrUpdateGroup(ctx context.Context, req *gatewayv1.Group) (*gatewayv1.Empty, error) {
29 | // defer span.Finish()
30 |
31 | provider.Logger(ctx).Debugw("CreateOrUpdateGroup", map[string]interface{}{
32 | "request": req.String(),
33 | })
34 |
35 | createParams := GroupCreateParams{
36 | ID: req.GetId(),
37 | Strategy: req.GetStrategy().Enum().String(),
38 | Backends: req.GetBackends(),
39 | IsEnabled: req.GetIsEnabled(),
40 | LastRoutedBackend: req.GetLastRoutedBackend(),
41 | }
42 |
43 | err := s.core.CreateOrUpdateGroup(ctx, &createParams)
44 | if err != nil {
45 | return nil, err
46 | }
47 |
48 | return &gatewayv1.Empty{}, nil
49 | }
50 |
51 | // Get retrieves a single group record
52 | func (s *Server) GetGroup(ctx context.Context, req *gatewayv1.GroupGetRequest) (*gatewayv1.GroupGetResponse, error) {
53 | provider.Logger(ctx).Debugw("GetGroup", map[string]interface{}{
54 | "request": req.String(),
55 | })
56 |
57 | group, err := s.core.GetGroup(ctx, req.GetId())
58 | if err != nil {
59 | return nil, err
60 | }
61 | groupProto, err := toGroupResponseProto(group)
62 | if err != nil {
63 | return nil, err
64 | }
65 | return &gatewayv1.GroupGetResponse{Group: groupProto}, nil
66 | }
67 |
68 | // List fetches a list of filtered group records
69 | func (s *Server) ListAllGroups(ctx context.Context, req *gatewayv1.Empty) (*gatewayv1.GroupListAllResponse, error) {
70 | provider.Logger(ctx).Debugw("ListAllGroups", map[string]interface{}{
71 | "request": req.String(),
72 | })
73 | groups, err := s.core.GetAllGroups(ctx)
74 | if err != nil {
75 | return nil, err
76 | }
77 |
78 | groupsProto := make([]*gatewayv1.Group, len(groups))
79 | for i, groupModel := range groups {
80 | group, err := toGroupResponseProto(&groupModel)
81 | if err != nil {
82 | return nil, err
83 | }
84 | groupsProto[i] = group
85 | }
86 |
87 | response := gatewayv1.GroupListAllResponse{
88 | Items: groupsProto,
89 | }
90 |
91 | return &response, nil
92 | }
93 |
94 | // Approve marks a groups status to approved
95 |
96 | func (s *Server) EnableGroup(ctx context.Context, req *gatewayv1.GroupEnableRequest) (*gatewayv1.Empty, error) {
97 | provider.Logger(ctx).Debugw("EnableGroup", map[string]interface{}{
98 | "request": req.String(),
99 | })
100 | err := s.core.EnableGroup(ctx, req.GetId())
101 | if err != nil {
102 | return nil, err
103 | }
104 |
105 | return &gatewayv1.Empty{}, nil
106 | }
107 |
108 | func (s *Server) DisableGroup(ctx context.Context, req *gatewayv1.GroupDisableRequest) (*gatewayv1.Empty, error) {
109 | provider.Logger(ctx).Debugw("DisableGroup", map[string]interface{}{
110 | "request": req.String(),
111 | })
112 | err := s.core.DisableGroup(ctx, req.GetId())
113 | if err != nil {
114 | return nil, err
115 | }
116 |
117 | return &gatewayv1.Empty{}, nil
118 | }
119 |
120 | // Delete deletes a group, soft-delete
121 | func (s *Server) DeleteGroup(ctx context.Context, req *gatewayv1.GroupDeleteRequest) (*gatewayv1.Empty, error) {
122 | provider.Logger(ctx).Debugw("DeleteGroup", map[string]interface{}{
123 | "request": req.String(),
124 | })
125 | err := s.core.DeleteGroup(ctx, req.GetId())
126 | if err != nil {
127 | return nil, err
128 | }
129 |
130 | return &gatewayv1.Empty{}, nil
131 | }
132 |
133 | func toGroupResponseProto(group *models.Group) (*gatewayv1.Group, error) {
134 | if group == nil {
135 | return &gatewayv1.Group{}, nil
136 | }
137 | strategy, ok := gatewayv1.Group_RoutingStrategy_value[strings.ToUpper(*group.Strategy)]
138 | if !ok {
139 | return nil, errors.New(fmt.Sprint("error encoding response: invalid strategy ", *group.Strategy))
140 | }
141 | var backends []string
142 | for _, backend := range group.GroupBackendsMappings {
143 | backends = append(backends, backend.BackendId)
144 | }
145 | response := gatewayv1.Group{
146 | Id: group.ID,
147 | Strategy: *gatewayv1.Group_RoutingStrategy(strategy).Enum(),
148 | Backends: backends,
149 | IsEnabled: *group.IsEnabled,
150 | LastRoutedBackend: *group.LastRoutedBackend,
151 | }
152 |
153 | return &response, nil
154 | }
155 |
156 | func (s *Server) EvaluateBackendForGroups(ctx context.Context, req *gatewayv1.EvaluateBackendRequest) (*gatewayv1.EvaluateBackendResponse, error) {
157 | provider.Logger(ctx).Debugw("EvaluateBackendForGroups", map[string]interface{}{
158 | "request": req.String(),
159 | })
160 |
161 | backend_id, group_id, err := s.core.EvaluateBackendForGroups(ctx, req.GetGroupIds())
162 | if err != nil {
163 | return nil, err
164 |
165 | }
166 | return &gatewayv1.EvaluateBackendResponse{BackendId: backend_id, GroupId: group_id}, nil
167 | }
168 |
--------------------------------------------------------------------------------
/internal/gatewayserver/groupApi/validation.go:
--------------------------------------------------------------------------------
1 | package groupapi
2 |
3 | // import (
4 | // validation "github.com/go-ozzo/ozzo-validation/v4"
5 | // )
6 |
7 | // func (cp *CreateParams) Validate() error {
8 | // err := validation.ValidateStruct(cp,
9 | // // id, required, length non zero
10 | // validation.Field(&cp.ID, validation.Required, validation.RuneLength(1, 50)),
11 |
12 | // // Hostname, required, string, length 1-50
13 | // validation.Field(&cp.Hostname, validation.Required, validation.RuneLength(1, 50)),
14 |
15 | // // Scheme, required, string, Union(http, https)
16 | // validation.Field(&cp.Scheme, validation.Required, validation.In("http", "https")),
17 |
18 | // // // last_name, required, string, length 1-30
19 | // // validation.Field(&cp.LastName, validation.Required, validation.RuneLength(1, 30)),
20 | // )
21 |
22 | // return err
23 | // // if err == nil {
24 | // // return nil
25 | // // }
26 |
27 | // // publicErr := errorclass.ErrValidationFailure.New("").
28 | // // Wrap(err).
29 | // // WithPublic(&errors.Public{
30 | // // Description: err.Error(),
31 | // // })
32 |
33 | // // return publicErr
34 | // }
35 |
--------------------------------------------------------------------------------
/internal/gatewayserver/healthApi/core.go:
--------------------------------------------------------------------------------
1 | package healthapi
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "sync"
7 | )
8 |
9 | // Core holds business logic and/or orchestrator of other things in the package.
10 | type Core struct {
11 | isHealthy bool
12 | mutex sync.Mutex
13 | }
14 |
15 | // NewCore creates Core.
16 | func NewCore() *Core {
17 | return &Core{
18 | isHealthy: true,
19 | }
20 | }
21 |
22 | // RunHealthCheck runs various server checks and returns true if all individual components are working fine.
23 | // Todo: Fix server check response per https://tools.ietf.org/id/draft-inadarei-api-health-check-01.html :)
24 | func (c *Core) RunHealthCheck(ctx context.Context) (bool, error) {
25 | if !c.isHealthy {
26 | return false, fmt.Errorf("server marked unhealthy")
27 | }
28 |
29 | var err error
30 |
31 | // Checks the DB connection exists and is alive by executing a select query.
32 | // err = c.repo.Alive(ctx)
33 | if err != nil {
34 | // logger.Ctx(ctx).Errorw("failed to execute select query on db connection", "error", err)
35 |
36 | // Check fallback routing group exists & has atleast 1 active backend - IF NO set state to INVALID
37 | //
38 | }
39 | isDbAlive := err == nil
40 |
41 | return isDbAlive, err
42 | }
43 |
44 | // MarkUnhealthy marks the server as unhealthy for health check to return negative
45 | func (c *Core) MarkUnhealthy() {
46 | c.mutex.Lock()
47 | c.isHealthy = false
48 | c.mutex.Unlock()
49 | }
50 |
--------------------------------------------------------------------------------
/internal/gatewayserver/healthApi/server.go:
--------------------------------------------------------------------------------
1 | package healthapi
2 |
3 | import (
4 | "context"
5 |
6 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
7 | "github.com/twitchtv/twirp"
8 | )
9 |
10 | // Server has methods implementing of server rpc.
11 | type Server struct {
12 | core *Core
13 | }
14 |
15 | // NewServer returns a server.
16 | func NewServer(core *Core) *Server {
17 | return &Server{
18 | core: core,
19 | }
20 | }
21 |
22 | // Check returns service's serving status.
23 | func (s *Server) Check(ctx context.Context, req *gatewayv1.HealthCheckRequest) (*gatewayv1.HealthCheckResponse, error) {
24 | var status gatewayv1.HealthCheckResponse_ServingStatus
25 | ok, err := s.core.RunHealthCheck(ctx)
26 | if !ok {
27 | status = gatewayv1.HealthCheckResponse_SERVING_STATUS_NOT_SERVING
28 | return &gatewayv1.HealthCheckResponse{ServingStatus: status}, twirp.NewError(twirp.Unavailable, err.Error())
29 | }
30 | status = gatewayv1.HealthCheckResponse_SERVING_STATUS_SERVING
31 | return &gatewayv1.HealthCheckResponse{ServingStatus: status}, nil
32 | }
33 |
--------------------------------------------------------------------------------
/internal/gatewayserver/hooks/auth.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
import (
	"context"
	"crypto/subtle"
	"fmt"
	"net/http"
	"strings"

	"github.com/twitchtv/twirp"

	"github.com/razorpay/trino-gateway/internal/boot"
)
13 |
// contextkey is an unexported key type for values this package stores in a
// request context, so they cannot collide with keys set by other packages.
type contextkey int

const (
	// authTokenCtxKey holds the API token read from the request header.
	authTokenCtxKey contextkey = iota
	// authUrlPathCtxKey holds the request's URL path.
	authUrlPathCtxKey
)
20 |
21 | func Auth() *twirp.ServerHooks {
22 | hooks := &twirp.ServerHooks{}
23 |
24 | hooks.RequestReceived = func(ctx context.Context) (context.Context, error) {
25 | m, _ := ctx.Value(authUrlPathCtxKey).(string)
26 | if strings.Contains(m, "/Get") || strings.Contains(m, "/List") {
27 | return ctx, nil
28 | }
29 |
30 | token, _ := ctx.Value(authTokenCtxKey).(string)
31 |
32 | if token == "" {
33 | return ctx, twirp.NewError(
34 | twirp.Unauthenticated,
35 | fmt.Sprint(
36 | "empty/undefined apiToken in header: ",
37 | boot.Config.Auth.TokenHeaderKey),
38 | )
39 | }
40 |
41 | if boot.Config.Auth.Token == token {
42 | return ctx, nil
43 | }
44 |
45 | return ctx, twirp.NewError(twirp.Unauthenticated, "invalid apiToken for authentication")
46 | }
47 |
48 | return hooks
49 | }
50 |
51 | func WithAuth(h http.Handler) http.Handler {
52 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
53 | ctx := r.Context()
54 | token := r.Header.Get(boot.Config.Auth.TokenHeaderKey)
55 | urlPath := r.URL.Path
56 |
57 | ctx = context.WithValue(ctx, authTokenCtxKey, token)
58 | ctx = context.WithValue(ctx, authUrlPathCtxKey, urlPath)
59 |
60 | r = r.WithContext(ctx)
61 |
62 | h.ServeHTTP(w, r)
63 | })
64 | }
65 |
--------------------------------------------------------------------------------
/internal/gatewayserver/hooks/ctx.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/razorpay/trino-gateway/pkg/logger"
7 | "github.com/twitchtv/twirp"
8 |
9 | "github.com/razorpay/trino-gateway/internal/boot"
10 | "github.com/razorpay/trino-gateway/internal/provider"
11 | )
12 |
13 | // Ctx returns function which sets context with core service
14 | // information and puts contextual logger into same for later use.
15 | func Ctx() *twirp.ServerHooks {
16 | hooks := &twirp.ServerHooks{}
17 |
18 | hooks.RequestRouted = func(ctx context.Context) (context.Context, error) {
19 | ctx = boot.WithRequestID(ctx, "")
20 |
21 | // Adds more contextual info in above logger.
22 | // Todo: Check why method, service and package names are not known in this hook.
23 | reqMethod, _ := twirp.MethodName(ctx)
24 | reqService, _ := twirp.ServiceName(ctx)
25 | reqPackage, _ := twirp.PackageName(ctx)
26 | req := map[string]interface{}{
27 | "reqId": boot.GetRequestID(ctx),
28 | // TODO: set auth user in auth Hook
29 | // "reqUser": ctx.Value("authUserCtxKey"),
30 | "reqMethod": reqMethod,
31 | "reqService": reqService,
32 | "reqPackage": reqPackage,
33 | }
34 |
35 | return context.WithValue(ctx, logger.LoggerCtxKey, provider.Logger(ctx).WithFields(req)), nil
36 | }
37 |
38 | return hooks
39 | }
40 |
--------------------------------------------------------------------------------
/internal/gatewayserver/hooks/metric.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "time"
7 |
8 | "github.com/razorpay/trino-gateway/internal/gatewayserver/metrics"
9 | "github.com/twitchtv/twirp"
10 | )
11 |
// reqStartTsCtxKey is the package-private context key under which the request
// start time is stored. Using a freshly allocated pointer gives the key a
// unique identity, so it cannot collide with keys from other packages.
var reqStartTsCtxKey *int = new(int)
13 |
14 | // Metric returns function which puts unique request id into context.
15 | func Metric() *twirp.ServerHooks {
16 | hooks := &twirp.ServerHooks{}
17 |
18 | // RequestReceived:
19 | hooks.RequestReceived = func(ctx context.Context) (context.Context, error) {
20 | ctx = markRequestStart(ctx)
21 |
22 | return ctx, nil
23 | }
24 |
25 | // RequestRouted:
26 | hooks.RequestRouted = func(ctx context.Context) (context.Context, error) {
27 | pkg, _ := twirp.PackageName(ctx)
28 | service, _ := twirp.ServiceName(ctx)
29 | method, _ := twirp.MethodName(ctx)
30 |
31 | metrics.RequestsReceivedTotal.
32 | WithLabelValues(pkg, service, method).
33 | Inc()
34 |
35 | return ctx, nil
36 | }
37 |
38 | // ResponseSent:
39 | hooks.ResponseSent = func(ctx context.Context) {
40 | start, _ := getRequestStart(ctx)
41 | pkg, _ := twirp.PackageName(ctx)
42 | service, _ := twirp.ServiceName(ctx)
43 | method, _ := twirp.MethodName(ctx)
44 | statusCode, _ := twirp.StatusCode(ctx)
45 |
46 | duration := float64(time.Now().Sub(start).Milliseconds())
47 |
48 | metrics.ResponsesSentTotal.WithLabelValues(
49 | pkg, service, method,
50 | fmt.Sprintf("%v", statusCode),
51 | ).Inc()
52 |
53 | metrics.ResponseDurations.WithLabelValues(
54 | pkg, service, method,
55 | fmt.Sprintf("%v", statusCode),
56 | ).Observe(duration)
57 | }
58 |
59 | return hooks
60 | }
61 |
62 | func markRequestStart(ctx context.Context) context.Context {
63 | return context.WithValue(ctx, reqStartTsCtxKey, time.Now())
64 | }
65 |
66 | func getRequestStart(ctx context.Context) (time.Time, bool) {
67 | t, ok := ctx.Value(reqStartTsCtxKey).(time.Time)
68 | return t, ok
69 | }
70 |
--------------------------------------------------------------------------------
/internal/gatewayserver/hooks/requestid.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | "context"
5 | "net/http"
6 |
7 | "github.com/rs/xid"
8 | "github.com/twitchtv/twirp"
9 |
10 | "github.com/razorpay/trino-gateway/internal/constants/contextkeys"
11 | )
12 |
// requestIDHttpHeaderKey is the incoming HTTP header whose value is copied
// into the request context as the request ID.
const requestIDHttpHeaderKey = "X-Request-ID"
16 |
17 | // RequestID returns function which puts unique request id into context.
18 | func RequestID() *twirp.ServerHooks {
19 | hooks := &twirp.ServerHooks{}
20 |
21 | hooks.RequestRouted = func(ctx context.Context) (context.Context, error) {
22 | // var err error
23 | requestID, _ := ctx.Value(contextkeys.RequestID).(string)
24 | if requestID == "" {
25 | requestID = xid.New().String()
26 | }
27 | ctx = context.WithValue(ctx, contextkeys.RequestID, requestID)
28 | return ctx, nil
29 | }
30 |
31 | return hooks
32 | }
33 |
34 | // WithRequestID is a http handler which puts specific http request header into
35 | // request context which is made available in twirp hooks.
36 | // Refer: https://twitchtv.github.io/twirp/docs/headers.html
37 | func WithRequestID(h http.Handler) http.Handler {
38 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
39 | ctx := setCtxRequestId(r.Context(), r)
40 | r = r.WithContext(ctx)
41 | h.ServeHTTP(w, r)
42 | })
43 | }
44 |
45 | func setCtxRequestId(ctx context.Context, r *http.Request) context.Context {
46 | requestID := r.Header.Get(requestIDHttpHeaderKey)
47 | return context.WithValue(ctx, contextkeys.RequestID, requestID)
48 | }
49 |
--------------------------------------------------------------------------------
/internal/gatewayserver/metrics/metrics.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import (
4 | "github.com/prometheus/client_golang/prometheus"
5 | "github.com/prometheus/client_golang/prometheus/promauto"
6 | "github.com/razorpay/trino-gateway/internal/boot"
7 | )
8 |
9 | var (
10 | RequestsReceivedTotal *prometheus.CounterVec
11 | ResponsesSentTotal *prometheus.CounterVec
12 | ResponseDurations *prometheus.HistogramVec
13 | FallbackGroupInvoked *prometheus.CounterVec
14 | )
15 |
16 | func init() {
17 | env := boot.Config.App.Env
18 | RequestsReceivedTotal = promauto.NewCounterVec(
19 | prometheus.CounterOpts{
20 | Name: "trino_gateway_http_requests_total",
21 | Help: "Number of HTTP requests received.",
22 | },
23 | []string{"env", "package", "server", "method"},
24 | ).MustCurryWith(prometheus.Labels{"env": env})
25 |
26 | ResponsesSentTotal = promauto.NewCounterVec(
27 | prometheus.CounterOpts{
28 | Name: "trino_gateway_http_responses_total",
29 | Help: "Number of HTTP responses sent.",
30 | },
31 | []string{"env", "package", "server", "method", "code"},
32 | ).MustCurryWith(prometheus.Labels{"env": env})
33 |
34 | ResponseDurations = promauto.NewHistogramVec(
35 | prometheus.HistogramOpts{
36 | Name: "trino_gateway_http_durations_ms_histogram",
37 | Help: "HTTP latency distributions histogram.",
38 | Buckets: []float64{2, 5, 10, 15, 25, 40, 60, 85, 120, 150, 200, 300},
39 | },
40 | []string{"env", "package", "server", "method", "code"},
41 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec)
42 |
43 | FallbackGroupInvoked = promauto.NewCounterVec(
44 | prometheus.CounterOpts{
45 | Name: "trino_gateway_fallback_group_invoked_total",
46 | Help: "Number of requests where fallback group routing was invoked",
47 | },
48 | []string{"env"},
49 | ).MustCurryWith(prometheus.Labels{"env": env})
50 | }
51 |
--------------------------------------------------------------------------------
/internal/gatewayserver/models/backend.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | validation "github.com/go-ozzo/ozzo-validation"
5 | "github.com/razorpay/trino-gateway/pkg/spine"
6 | )
7 |
8 | // backend model struct definition
9 | type Backend struct {
10 | spine.Model
11 | Hostname string `json:"hostname"`
12 | Scheme string `json:"scheme"`
13 | ExternalUrl *string `json:"external_url"`
14 | IsEnabled *bool `json:"is_enabled"`
15 | IsHealthy *bool `json:"is_healthy"`
16 | UptimeSchedule *string `json:"uptime_schedule" gorm:"default:'* * * * *';"`
17 | ClusterLoad *int32 `json:"cluster_load"`
18 | ThresholdClusterLoad *int32 `json:"threshold_cluster_load"`
19 | StatsUpdatedAt *int64 `json:"stats_updated_at"`
20 | }
21 |
22 | func (u *Backend) TableName() string {
23 | return "backends"
24 | }
25 |
26 | func (u *Backend) EntityName() string {
27 | return "backend"
28 | }
29 |
30 | func (u *Backend) SetDefaults() error {
31 | return nil
32 | }
33 |
34 | func (u *Backend) Validate() error {
35 | // fmt.Printf("{%v}\n", *u.StatsUpdatedAt)
36 | err := validation.ValidateStruct(u,
37 | // id, required, length non zero
38 | validation.Field(&u.ID, validation.Required, validation.RuneLength(1, 50)),
39 |
40 | // url, required, string, length 1-30
41 | validation.Field(&u.Hostname, validation.Required, validation.RuneLength(1, 255)),
42 |
43 | // Scheme, required, string, Union(http, https)
44 | validation.Field(&u.Scheme, validation.Required, validation.In("http", "https")),
45 |
46 | // first_name, required, string, length 1-30
47 | validation.Field(&u.ExternalUrl, validation.Required, validation.RuneLength(1, 255)),
48 |
49 | // validation.Field(&u.StatsUpdatedAt, validation.By(datatype.IsTimestamp)),
50 |
51 | // status, required, string
52 | // validation.Field(&u.IsEnabled, validation.Required, validation.In(true, false)),
53 | )
54 |
55 | return err
56 | }
57 |
--------------------------------------------------------------------------------
/internal/gatewayserver/models/group.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import "github.com/razorpay/trino-gateway/pkg/spine"
4 |
5 | // group model struct definition
6 | type Group struct {
7 | spine.Model
8 | Strategy *string `json:"strategy"`
9 | IsEnabled *bool `json:"is_enabled" sql:"DEFAULT:true"`
10 | LastRoutedBackend *string `json:"last_routed_backend"`
11 | GroupBackendsMappings []GroupBackendsMapping `gorm:"foreignKey:GroupId;references:ID"`
12 | }
13 |
14 | func (u *Group) TableName() string {
15 | return "groups_"
16 | }
17 |
18 | func (u *Group) EntityName() string {
19 | return "group"
20 | }
21 |
22 | func (u *Group) SetDefaults() error {
23 | return nil
24 | }
25 |
26 | func (u *Group) Validate() error {
27 | return nil
28 | }
29 |
30 | type GroupBackendsMapping struct {
31 | spine.Model
32 | ID *int32 `json:"id" sql:"DEFAULT:NULL"`
33 | GroupId string `json:"group_id" gorm:"primaryKey"`
34 | BackendId string `json:"backend_id" gorm:"primaryKey"`
35 | }
36 |
37 | func (u *GroupBackendsMapping) TableName() string {
38 | return "group_backends_mappings"
39 | }
40 |
41 | func (u *GroupBackendsMapping) EntityName() string {
42 | return "group_backends_mappings"
43 | }
44 |
45 | func (u *GroupBackendsMapping) SetDefaults() error {
46 | return nil
47 | }
48 |
49 | func (u *GroupBackendsMapping) Validate() error {
50 | return nil
51 | }
52 |
--------------------------------------------------------------------------------
/internal/gatewayserver/models/policy.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import "github.com/razorpay/trino-gateway/pkg/spine"
4 |
5 | // policy model struct definition
6 | type Policy struct {
7 | spine.Model
8 | RuleType string `json:"rule_type"`
9 | RuleValue string `json:"rule_value"`
10 | GroupId string `json:"group_id"`
11 | FallbackGroupId *string `json:"fallback_group_id"`
12 | IsEnabled *bool `json:"is_enabled" sql:"DEFAULT:true"`
13 | IsAuthDelegated *bool `json:"is_auth_delegated" sql:"DEFAULT:false"`
14 | SetRequestSource *string `json:"set_request_source"`
15 | }
16 |
17 | func (u *Policy) TableName() string {
18 | return "policies"
19 | }
20 |
21 | func (u *Policy) EntityName() string {
22 | return "policy"
23 | }
24 |
25 | func (u *Policy) SetDefaults() error {
26 | return nil
27 | }
28 |
29 | func (u *Policy) Validate() error {
30 | return nil
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/models/query.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import "github.com/razorpay/trino-gateway/pkg/spine"
4 |
5 | // query model struct definition
6 | type Query struct {
7 | spine.Model
8 | Text string `json:"text"`
9 | ClientIp string `json:"client_ip"`
10 | GroupId string `json:"group_id"`
11 | BackendId string `json:"backend_id"`
12 | Username string `json:"username"`
13 | SubmittedAt int64 `json:"submitted_at"`
14 | ServerHost string `json:"server_host"`
15 | }
16 |
17 | func (u *Query) TableName() string {
18 | return "queries"
19 | }
20 |
21 | func (u *Query) EntityName() string {
22 | return "query"
23 | }
24 |
25 | func (u *Query) SetDefaults() error {
26 | return nil
27 | }
28 |
29 | func (u *Query) Validate() error {
30 | return nil
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/policyApi/core_test.go:
--------------------------------------------------------------------------------
1 | package policyapi
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func Test_setIntersection(t *testing.T) {
10 | s_empty := map[string]struct{}{}
11 | s_nil := map[string]struct{}(nil)
12 | s1 := map[string]struct{}{
13 | "a": {},
14 | "b": {},
15 | }
16 | s2 := map[string]struct{}{
17 | "c": {},
18 | "d": {},
19 | }
20 | s3 := map[string]struct{}{
21 | "a": {},
22 | "d": {},
23 | }
24 | s13 := map[string]struct{}{
25 | "a": {},
26 | }
27 | s23 := map[string]struct{}{
28 | "d": {},
29 | }
30 |
31 | // both empty sets
32 | assert.Equal(t, s_empty, setIntersection(s_empty, s_empty))
33 |
34 | // first set empty
35 | assert.Equal(t, s_empty, setIntersection(s_empty, s1))
36 |
37 | // second set empty
38 | assert.Equal(t, s_empty, setIntersection(s1, s_empty))
39 |
40 | // both non-empty sets with empty intersection
41 | assert.Equal(t, s_empty, setIntersection(s1, s2))
42 |
43 | // both non-empty sets with non-empty intersection
44 | assert.Equal(t, s13, setIntersection(s1, s3))
45 | assert.Equal(t, s23, setIntersection(s3, s2))
46 | // reverse order
47 | assert.Equal(t, s13, setIntersection(s3, s1))
48 | assert.Equal(t, s23, setIntersection(s2, s3))
49 |
50 | // one set nil
51 | assert.Equal(t, s3, setIntersection(s_nil, s3))
52 |
53 | // both sets nil
54 | assert.Equal(t, s_nil, setIntersection(s_nil, s_nil))
55 |
56 | // nested stuff
57 | assert.Equal(t, s13, setIntersection(setIntersection(setIntersection(s1, s_nil), s_nil), s3))
58 | }
59 |
--------------------------------------------------------------------------------
/internal/gatewayserver/policyApi/validation.go:
--------------------------------------------------------------------------------
1 | package policyapi
2 |
3 | // import (
4 | // validation "github.com/go-ozzo/ozzo-validation/v4"
5 | // )
6 |
7 | // func (cp *CreateParams) Validate() error {
8 | // err := validation.ValidateStruct(cp,
9 | // // id, required, length non zero
10 | // validation.Field(&cp.ID, validation.Required, validation.RuneLength(1, 50)),
11 |
12 | // // Hostname, required, string, length 1-50
13 | // validation.Field(&cp.Hostname, validation.Required, validation.RuneLength(1, 50)),
14 |
15 | // // Scheme, required, string, Union(http, https)
16 | // validation.Field(&cp.Scheme, validation.Required, validation.In("http", "https")),
17 |
18 | // // // last_name, required, string, length 1-30
19 | // // validation.Field(&cp.LastName, validation.Required, validation.RuneLength(1, 30)),
20 | // )
21 |
22 | // return err
23 | // // if err == nil {
24 | // // return nil
25 | // // }
26 |
27 | // // publicErr := errorclass.ErrValidationFailure.New("").
28 | // // Wrap(err).
29 | // // WithPublic(&errors.Public{
30 | // // Description: err.Error(),
31 | // // })
32 |
33 | // // return publicErr
34 | // }
35 |
--------------------------------------------------------------------------------
/internal/gatewayserver/queryApi/core.go:
--------------------------------------------------------------------------------
1 | package queryapi
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/fatih/structs"
7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models"
8 | "github.com/razorpay/trino-gateway/internal/gatewayserver/repo"
9 | fetcherPkg "github.com/razorpay/trino-gateway/pkg/fetcher"
10 | )
11 |
12 | var entityName string = (&models.Query{}).EntityName()
13 |
14 | type Core struct {
15 | queryRepo repo.IQueryRepo
16 | fetcher fetcherPkg.IClient
17 | }
18 |
19 | type ICore interface {
20 | CreateOrUpdateQuery(ctx context.Context, params *QueryCreateParams) error
21 | GetQuery(ctx context.Context, id string) (*models.Query, error)
22 | FindMany(ctx context.Context, params IFindManyParams) ([]models.Query, error)
23 | }
24 |
25 | func NewCore(query repo.IQueryRepo, fetcher fetcherPkg.IClient) *Core {
26 | if !fetcher.IsEntityRegistered(entityName) {
27 | fetcher.Register(entityName, &models.Query{}, &[]models.Query{})
28 | }
29 | return &Core{
30 | queryRepo: query,
31 | fetcher: fetcher,
32 | }
33 | }
34 |
// QueryCreateParams carries the attributes Core.CreateOrUpdateQuery needs to
// insert or update a query record.
type QueryCreateParams struct {
	ID          string
	Text        string
	ClientIp    string
	BackendId   string
	Username    string
	GroupId     string
	ServerHost  string
	SubmittedAt int64
}
46 |
47 | func (c *Core) CreateOrUpdateQuery(ctx context.Context, params *QueryCreateParams) error {
48 | query := models.Query{
49 | Text: params.Text,
50 | ClientIp: params.ClientIp,
51 | BackendId: params.BackendId,
52 | Username: params.Username,
53 | GroupId: params.GroupId,
54 | ServerHost: params.ServerHost,
55 | SubmittedAt: params.SubmittedAt,
56 | }
57 | query.ID = params.ID
58 | _, exists := c.queryRepo.Find(ctx, params.ID)
59 | if exists == nil { // update
60 | return c.queryRepo.Update(ctx, &query)
61 | } else { // create
62 | return c.queryRepo.Create(ctx, &query)
63 | }
64 | }
65 |
66 | func (c *Core) GetQuery(ctx context.Context, id string) (*models.Query, error) {
67 | query, err := c.queryRepo.Find(ctx, id)
68 | return query, err
69 | }
70 |
// IFindManyParams is the subset of a list request that FindMany needs.
// Generated request messages satisfy it through their getters.
type IFindManyParams interface {
	// Paging and time window.
	GetCount() int32
	GetSkip() int32
	GetFrom() int64
	GetTo() int64

	// Query-specific equality filters; an empty string means "no filter".
	GetUsername() string
	GetBackendId() string
	GetGroupId() string
}
83 |
// Filters holds the optional equality filters used by FindMany. FindMany
// converts it to a map keyed by the json tag names. Because of omitempty,
// fields left as "" are dropped rather than matched literally.
type Filters struct {
	Username  string `json:"username,omitempty"`
	BackendId string `json:"backend_id,omitempty"`
	GroupId   string `json:"group_id,omitempty"`
}
90 |
91 | func (c *Core) FindMany(ctx context.Context, params IFindManyParams) ([]models.Query, error) {
92 | conditionStr := structs.New(Filters{
93 | Username: params.GetUsername(),
94 | BackendId: params.GetBackendId(),
95 | GroupId: params.GetGroupId(),
96 | })
97 | // use the json tag name, so we can respect omitempty tags
98 | conditionStr.TagName = "json"
99 | conditions := conditionStr.Map()
100 |
101 | // return c.queryRepo.FindMany(ctx, conditions)
102 |
103 | pagination := fetcherPkg.Pagination{}
104 |
105 | pagination.Skip = int(params.GetSkip())
106 | pagination.Limit = int(params.GetCount())
107 |
108 | t := params
109 | timeRange := fetcherPkg.TimeRange{}
110 | if params.GetFrom() != int64(0) && params.GetTo() != int64(0) {
111 | timeRange.From = t.GetFrom()
112 | timeRange.To = t.GetTo()
113 | }
114 |
115 | fetchRequest := fetcherPkg.FetchMultipleRequest{
116 | EntityName: entityName,
117 | Filter: conditions,
118 | Pagination: pagination,
119 | TimeRange: timeRange,
120 | IsTrashed: false,
121 | HasCreatedAt: true,
122 | }
123 |
124 | resp, err := c.fetcher.FetchMultiple(ctx, fetchRequest)
125 | if err != nil {
126 | return nil, err
127 | }
128 |
129 | queries := (resp.GetEntities().(map[string]interface{})[entityName]).(*[]models.Query)
130 |
131 | return *queries, nil
132 | }
133 |
--------------------------------------------------------------------------------
/internal/gatewayserver/queryApi/server.go:
--------------------------------------------------------------------------------
1 | package queryapi
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models"
7 | "github.com/razorpay/trino-gateway/internal/provider"
8 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
9 | _ "github.com/twitchtv/twirp"
10 | )
11 |
12 | // Server has methods implementing of server rpc.
13 | type Server struct {
14 | core ICore
15 | }
16 |
17 | // NewServer returns a server.
18 | func NewServer(core ICore) *Server {
19 | return &Server{
20 | core: core,
21 | }
22 | }
23 |
24 | func (s *Server) CreateOrUpdateQuery(ctx context.Context, req *gatewayv1.Query) (*gatewayv1.Empty, error) {
25 | provider.Logger(ctx).Debugw("CreateOrUpdateQuery", map[string]interface{}{
26 | "request": req.String(),
27 | })
28 |
29 | createParams := QueryCreateParams{
30 | ID: req.GetId(),
31 | Text: req.GetText(),
32 | ClientIp: req.GetClientIp(),
33 | GroupId: req.GetGroupId(),
34 | BackendId: req.GetBackendId(),
35 | Username: req.GetUsername(),
36 | ServerHost: req.GetServerHost(),
37 | SubmittedAt: req.GetSubmittedAt(),
38 | }
39 |
40 | err := s.core.CreateOrUpdateQuery(ctx, &createParams)
41 | if err != nil {
42 | return nil, err
43 | }
44 |
45 | return &gatewayv1.Empty{}, nil
46 | }
47 |
48 | func (s *Server) GetQuery(ctx context.Context, req *gatewayv1.QueryGetRequest) (*gatewayv1.QueryGetResponse, error) {
49 | provider.Logger(ctx).Debugw("GetQuery", map[string]interface{}{
50 | "request": req.String(),
51 | })
52 | query, err := s.core.GetQuery(ctx, req.GetId())
53 | if err != nil {
54 | return nil, err
55 | }
56 | queryProto, err := toQueryResponseProto(query)
57 | if err != nil {
58 | return nil, err
59 | }
60 | return &gatewayv1.QueryGetResponse{Query: queryProto}, nil
61 | }
62 |
63 | func (s *Server) ListQueries(ctx context.Context, req *gatewayv1.QueriesListRequest) (*gatewayv1.QueriesListResponse, error) {
64 | provider.Logger(ctx).Debugw("ListQueries", map[string]interface{}{
65 | "request": req.String(),
66 | })
67 | // TODO
68 |
69 | if err := ValidateMultiFetchRequest(ctx, req); err != nil {
70 | return nil, err
71 | }
72 |
73 | queries, err := s.core.FindMany(ctx, req)
74 | if err != nil {
75 | return nil, err
76 | }
77 |
78 | queriesProto := make([]*gatewayv1.Query, len(queries))
79 | for i, queryModel := range queries {
80 | query, err := toQueryResponseProto(&queryModel)
81 | if err != nil {
82 | return nil, err
83 | }
84 | queriesProto[i] = query
85 | }
86 |
87 | response := gatewayv1.QueriesListResponse{
88 | Items: queriesProto,
89 | Count: int32(len(queriesProto)),
90 | }
91 |
92 | return &response, nil
93 | }
94 |
95 | func toQueryResponseProto(query *models.Query) (*gatewayv1.Query, error) {
96 | if query == nil {
97 | return &gatewayv1.Query{}, nil
98 | }
99 | return &gatewayv1.Query{
100 | Id: query.ID,
101 | Text: query.Text,
102 | ServerHost: query.ServerHost,
103 | ClientIp: query.ClientIp,
104 | GroupId: query.GroupId,
105 | BackendId: query.BackendId,
106 | Username: query.Username,
107 | SubmittedAt: query.SubmittedAt,
108 | }, nil
109 | }
110 |
111 | func (s *Server) FindBackendForQuery(ctx context.Context, req *gatewayv1.FindBackendForQueryRequest) (*gatewayv1.FindBackendForQueryResponse, error) {
112 | provider.Logger(ctx).Debugw("FindBackendForQuery", map[string]interface{}{
113 | "request": req.String(),
114 | })
115 |
116 | query, err := s.core.GetQuery(ctx, req.QueryId)
117 | if err != nil {
118 | return nil, err
119 | }
120 | return &gatewayv1.FindBackendForQueryResponse{
121 | BackendId: query.BackendId,
122 | GroupId: query.GroupId,
123 | }, nil
124 | }
125 |
--------------------------------------------------------------------------------
/internal/gatewayserver/queryApi/validation.go:
--------------------------------------------------------------------------------
1 | package queryapi
2 |
3 | import (
4 | "context"
5 |
6 | // validation "github.com/go-ozzo/ozzo-validation/v4"
7 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
8 | )
9 |
10 | // func (cp *CreateParams) Validate() error {
11 | // err := validation.ValidateStruct(cp,
12 | // // id, required, length non zero
13 | // validation.Field(&cp.ID, validation.Required, validation.RuneLength(1, 50)),
14 |
15 | // // Hostname, required, string, length 1-50
16 | // validation.Field(&cp.Hostname, validation.Required, validation.RuneLength(1, 50)),
17 |
18 | // // Scheme, required, string, Union(http, https)
19 | // validation.Field(&cp.Scheme, validation.Required, validation.In("http", "https")),
20 |
21 | // // // last_name, required, string, length 1-30
22 | // // validation.Field(&cp.LastName, validation.Required, validation.RuneLength(1, 30)),
23 | // )
24 |
25 | // return err
26 | // // if err == nil {
27 | // // return nil
28 | // // }
29 |
30 | // // publicErr := errorclass.ErrValidationFailure.New("").
31 | // // Wrap(err).
32 | // // WithPublic(&errors.Public{
33 | // // Description: err.Error(),
34 | // // })
35 |
36 | // // return publicErr
37 | // }
38 |
39 | func ValidateMultiFetchRequest(ctx context.Context, req *gatewayv1.QueriesListRequest) error {
40 | // err := validation.ValidateStruct(
41 | // req,
42 | // validation.Field(&req.EntityName,
43 | // validation.Required))
44 | // if err != nil {
45 | // return rzperror.New(ctx, errorCodes.VALIDATION_ERROR, err).Report()
46 | // }
47 | return nil
48 | }
49 |
--------------------------------------------------------------------------------
/internal/gatewayserver/repo/backend.go:
--------------------------------------------------------------------------------
1 | package repo
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/razorpay/trino-gateway/internal/gatewayserver/database/dbRepo"
7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models"
8 | "github.com/razorpay/trino-gateway/internal/provider"
9 | "github.com/razorpay/trino-gateway/pkg/spine"
10 | )
11 |
12 | type IBackendRepo interface {
13 | Create(ctx context.Context, backend *models.Backend) error
14 | Update(ctx context.Context, backend *models.Backend) error
15 | Find(ctx context.Context, id string) (*models.Backend, error)
16 | FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Backend, error)
17 | // GetAll(ctx context.Context) ([]models.Backend, error)
18 | // GetAllActive(ctx context.Context) ([]models.Backend, error)
19 | GetAllActiveByIDs(ctx context.Context, ids []string) ([]models.Backend, error)
20 | Delete(ctx context.Context, id string) error
21 | Enable(ctx context.Context, id string) error
22 | Disable(ctx context.Context, id string) error
23 | MarkHealthy(ctx context.Context, id string) error
24 | MarkUnhealthy(ctx context.Context, id string) error
25 | }
26 |
27 | type BackendRepo struct {
28 | repo dbRepo.IDbRepo
29 | }
30 |
31 | // NewBackendRepo returns a new instance of *BackendRepo
32 | func NewBackendRepo(repo dbRepo.IDbRepo) *BackendRepo {
33 | return &BackendRepo{repo: repo}
34 | }
35 |
36 | func (r *BackendRepo) Create(ctx context.Context, backend *models.Backend) error {
37 | err := r.repo.Create(ctx, backend)
38 | if err != nil {
39 | provider.Logger(ctx).WithError(err).Error("backend create failed")
40 | return err
41 | }
42 |
43 | provider.Logger(ctx).Infow("backend created", map[string]interface{}{"backend_id": backend.ID})
44 |
45 | return nil
46 | }
47 |
48 | func (r *BackendRepo) Update(ctx context.Context, backend *models.Backend) error {
49 | err := r.repo.Update(ctx, backend)
50 | if err != nil {
51 | if err == spine.NoRowAffected {
52 | provider.Logger(ctx).Debugw(
53 | "no row affected by backend update",
54 | map[string]interface{}{"backend_id": backend.ID},
55 | )
56 | return nil
57 | }
58 | provider.Logger(ctx).WithError(err).Errorw(
59 | "backend update failed",
60 | map[string]interface{}{"backend_id": backend.ID},
61 | )
62 | return err
63 | }
64 |
65 | provider.Logger(ctx).Infow("backend updated", map[string]interface{}{"backend_id": backend.ID})
66 |
67 | return nil
68 | }
69 |
70 | func (r *BackendRepo) Find(ctx context.Context, id string) (*models.Backend, error) {
71 | backend := models.Backend{}
72 |
73 | err := r.repo.FindByID(ctx, &backend, id)
74 | if err != nil {
75 | return nil, err
76 | }
77 |
78 | return &backend, nil
79 | }
80 |
81 | // Returns list of backend ids which are ready to serve traffic
82 | func (r *BackendRepo) GetAllActiveByIDs(ctx context.Context, ids []string) ([]models.Backend, error) {
83 | var backends []models.Backend
84 |
85 | err := r.repo.FindWithConditionByIDs(
86 | ctx,
87 | &backends,
88 | map[string]interface{}{"is_enabled": true, "is_healthy": true},
89 | ids,
90 | )
91 | if err != nil {
92 | return nil, err
93 | }
94 |
95 | return backends, nil
96 | }
97 |
98 | func (r *BackendRepo) FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Backend, error) {
99 | var backends []models.Backend
100 |
101 | err := r.repo.FindMany(ctx, &backends, conditions)
102 | if err != nil {
103 | return nil, err
104 | }
105 |
106 | return backends, nil
107 | }
108 |
109 | // func (r *BackendRepo) GetAll(ctx context.Context) ([]models.Backend, error) {
110 | // var backends []models.Backend
111 |
112 | // err := r.repo.FindMany(ctx, &backends, make(map[string]interface{}))
113 | // if err != nil {
114 | // return nil, err
115 | // }
116 |
117 | // return &backends, nil
118 | // }
119 |
120 | // func (r *BackendRepo) GetAllActive(ctx context.Context) ([]models.Backend, error) {
121 | // var backends []models.Backend
122 |
123 | // err := r.repo.FindMany(ctx, &backends, map[string]interface{}{"is_enabled": true})
124 | // if err != nil {
125 | // return nil, err
126 | // }
127 |
128 | // return &backends, nil
129 | // }
130 |
131 | func (r *BackendRepo) Enable(ctx context.Context, id string) error {
132 | provider.Logger(ctx).Infow("backend activation triggered", map[string]interface{}{"backend_id": id})
133 |
134 | backend, err := r.Find(ctx, id)
135 | if err != nil {
136 | provider.Logger(ctx).Error("backend activation failed: " + err.Error())
137 | return err
138 | }
139 |
140 | if *backend.IsEnabled {
141 | provider.Logger(ctx).Info("backend activation failed. Already active")
142 | return nil
143 | }
144 |
145 | *backend.IsEnabled = true
146 |
147 | if err := r.repo.Update(ctx, backend); err != nil {
148 | return err
149 | }
150 |
151 | return nil
152 | }
153 |
154 | func (r *BackendRepo) Disable(ctx context.Context, id string) error {
155 | provider.Logger(ctx).Infow("backend deactivation triggered", map[string]interface{}{"backend_id": id})
156 |
157 | backend, err := r.Find(ctx, id)
158 | if err != nil {
159 | provider.Logger(ctx).Error("backend deactivation failed: " + err.Error())
160 | return err
161 | }
162 |
163 | if !*backend.IsEnabled {
164 | provider.Logger(ctx).Info("backend deactivation failed. Already inactive")
165 | return nil
166 | }
167 |
168 | *backend.IsEnabled = false
169 |
170 | if err := r.repo.Update(ctx, backend); err != nil {
171 | return err
172 | }
173 |
174 | return nil
175 | }
176 |
177 | func (r *BackendRepo) MarkHealthy(ctx context.Context, id string) error {
178 | provider.Logger(ctx).Infow("backend mark as healthy triggered", map[string]interface{}{"backend_id": id})
179 |
180 | backend, err := r.Find(ctx, id)
181 | if err != nil {
182 | provider.Logger(ctx).Error("backend mark as healthy failed: " + err.Error())
183 | return err
184 | }
185 |
186 | if *backend.IsHealthy {
187 | provider.Logger(ctx).Info("backend mark as healthy failed. Already healthy")
188 | return nil
189 | }
190 |
191 | *backend.IsHealthy = true
192 |
193 | if err := r.repo.Update(ctx, backend); err != nil {
194 | return err
195 | }
196 |
197 | return nil
198 | }
199 |
200 | func (r *BackendRepo) MarkUnhealthy(ctx context.Context, id string) error {
201 | provider.Logger(ctx).Infow("backend mark as unhealthy triggered", map[string]interface{}{"backend_id": id})
202 |
203 | backend, err := r.Find(ctx, id)
204 | if err != nil {
205 | provider.Logger(ctx).Error("backend mark as unhealthy failed: " + err.Error())
206 | return err
207 | }
208 |
209 | if !*backend.IsHealthy {
210 | provider.Logger(ctx).Info("backend mark as unhealthy failed. Already unhealthy")
211 | return nil
212 | }
213 |
214 | *backend.IsHealthy = false
215 |
216 | if err := r.repo.Update(ctx, backend); err != nil {
217 | return err
218 | }
219 |
220 | return nil
221 | }
222 |
223 | func (r *BackendRepo) Delete(ctx context.Context, id string) error {
224 | provider.Logger(ctx).Infow("backend delete request", map[string]interface{}{"backend_id": id})
225 |
226 | backend, err := r.Find(ctx, id)
227 | if err != nil {
228 | provider.Logger(ctx).Error("backend delete failed: " + err.Error())
229 | return err
230 | }
231 |
232 | // _ = backend
233 |
234 | err = r.repo.Delete(ctx, backend)
235 | if err != nil {
236 | return err
237 | }
238 |
239 | return nil
240 | }
241 |
--------------------------------------------------------------------------------
/internal/gatewayserver/repo/group.go:
--------------------------------------------------------------------------------
1 | package repo
2 |
3 | import (
4 | "context"
5 | "errors"
6 |
7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/database/dbRepo"
8 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models"
9 | "github.com/razorpay/trino-gateway/internal/provider"
10 | "github.com/razorpay/trino-gateway/pkg/spine"
11 | "gorm.io/gorm/clause"
12 | )
13 |
14 | type IGroupRepo interface {
15 | Create(ctx context.Context, group *models.Group) error
16 | Update(ctx context.Context, group *models.Group) error
17 | Find(ctx context.Context, id string) (*models.Group, error)
18 | FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Group, error)
19 | // GetAll(ctx context.Context) ([]models.Group, error)
20 | // GetAllActive(ctx context.Context) ([]models.Group, error)
21 | Delete(ctx context.Context, id string) error
22 | Enable(ctx context.Context, id string) error
23 | Disable(ctx context.Context, id string) error
24 | }
25 |
26 | type GroupRepo struct {
27 | repo dbRepo.IDbRepo
28 | }
29 |
30 | // NewCore returns a new instance of *Core
31 | func NewGroupRepo(repo dbRepo.IDbRepo) *GroupRepo {
32 | return &GroupRepo{repo: repo}
33 | }
34 |
35 | func (r *GroupRepo) Create(ctx context.Context, group *models.Group) error {
36 | err := r.repo.Create(ctx, group)
37 | if err != nil {
38 | provider.Logger(ctx).WithError(err).Errorw("group create failed", map[string]interface{}{"group_id": group.ID})
39 | return err
40 | }
41 |
42 | provider.Logger(ctx).Infow("group created", map[string]interface{}{"group_id": group.ID})
43 |
44 | return nil
45 | }
46 |
47 | func (r *GroupRepo) Update(ctx context.Context, group *models.Group) error {
48 | if group.GroupBackendsMappings != nil {
49 | err := r.repo.ReplaceAssociations(ctx, group, "GroupBackendsMappings", group.GroupBackendsMappings)
50 | if err != nil {
51 | provider.Logger(ctx).WithError(err).Errorw(
52 | "group update failed, unable to update associations",
53 | map[string]interface{}{"group_id": group.ID})
54 | return err
55 | }
56 | }
57 | err := r.repo.Update(ctx, group)
58 | if err != nil {
59 | if err == spine.NoRowAffected {
60 | provider.Logger(ctx).Debugw(
61 | "no row affected by group update",
62 | map[string]interface{}{"group_id": group.ID},
63 | )
64 | return nil
65 | }
66 | provider.Logger(ctx).WithError(err).Errorw(
67 | "group update failed",
68 | map[string]interface{}{"group_id": group.ID})
69 | return err
70 | }
71 |
72 | provider.Logger(ctx).Infow("group updated", map[string]interface{}{"group_id": group.ID})
73 |
74 | return nil
75 | }
76 |
77 | func (r *GroupRepo) Find(ctx context.Context, id string) (*models.Group, error) {
78 | group := models.Group{}
79 |
80 | err := r.repo.Preload(ctx, clause.Associations).FindByID(ctx, &group, id)
81 | if err != nil {
82 | return nil, err
83 | }
84 |
85 | return &group, nil
86 | }
87 |
88 | func (r *GroupRepo) FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Group, error) {
89 | var groups []models.Group
90 |
91 | err := r.repo.Preload(ctx, clause.Associations).FindMany(ctx, &groups, conditions)
92 | if err != nil {
93 | return nil, err
94 | }
95 |
96 | return groups, nil
97 | }
98 |
99 | func (r *GroupRepo) Enable(ctx context.Context, id string) error {
100 | provider.Logger(ctx).Infow("group activation triggered", map[string]interface{}{"group_id": id})
101 |
102 | group, err := r.Find(ctx, id)
103 | if err != nil {
104 | provider.Logger(ctx).Error("group activation failed: " + err.Error())
105 | return err
106 | }
107 |
108 | if *group.IsEnabled {
109 | provider.Logger(ctx).Error("group activation failed. Already active")
110 | return errors.New("already active")
111 | }
112 |
113 | *group.IsEnabled = true
114 |
115 | if err := r.repo.Update(ctx, group); err != nil {
116 | return err
117 | }
118 |
119 | return nil
120 | }
121 |
122 | func (r *GroupRepo) Disable(ctx context.Context, id string) error {
123 | provider.Logger(ctx).Infow("group activation triggered", map[string]interface{}{"group_id": id})
124 |
125 | group, err := r.Find(ctx, id)
126 | if err != nil {
127 | provider.Logger(ctx).Error("group activation failed: " + err.Error())
128 | return err
129 | }
130 |
131 | if !*group.IsEnabled {
132 | provider.Logger(ctx).Error("group deactivation failed. Already inactive")
133 | return errors.New("already inactive")
134 | }
135 |
136 | *group.IsEnabled = false
137 |
138 | if err := r.repo.Update(ctx, group); err != nil {
139 | return err
140 | }
141 |
142 | return nil
143 | }
144 |
145 | func (r *GroupRepo) Delete(ctx context.Context, id string) error {
146 | provider.Logger(ctx).Infow("group delete request", map[string]interface{}{"group_id": id})
147 |
148 | group, err := r.Find(ctx, id)
149 | if err != nil {
150 | provider.Logger(ctx).Error("group delete failed: " + err.Error())
151 | return err
152 | }
153 |
154 | // _ = group
155 |
156 | err = r.repo.Delete(ctx, group)
157 | if err != nil {
158 | return err
159 | }
160 |
161 | return nil
162 | }
163 |
--------------------------------------------------------------------------------
/internal/gatewayserver/repo/policy.go:
--------------------------------------------------------------------------------
1 | package repo
2 |
3 | import (
4 | "context"
5 | "errors"
6 |
7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/database/dbRepo"
8 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models"
9 | "github.com/razorpay/trino-gateway/internal/provider"
10 | "github.com/razorpay/trino-gateway/pkg/spine"
11 | )
12 |
13 | type IPolicyRepo interface {
14 | Create(ctx context.Context, policy *models.Policy) error
15 | Update(ctx context.Context, policy *models.Policy) error
16 | Find(ctx context.Context, id string) (*models.Policy, error)
17 | FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Policy, error)
18 | // GetAll(ctx context.Context) ([]models.Policy, error)
19 | // GetAllActive(ctx context.Context) ([]models.Policy, error)
20 | Delete(ctx context.Context, id string) error
21 | Enable(ctx context.Context, id string) error
22 | Disable(ctx context.Context, id string) error
23 | }
24 |
25 | type PolicyRepo struct {
26 | repo dbRepo.IDbRepo
27 | }
28 |
29 | // NewCore returns a new instance of *Core
30 | func NewPolicyRepo(repo dbRepo.IDbRepo) *PolicyRepo {
31 | return &PolicyRepo{repo: repo}
32 | }
33 |
34 | func (r *PolicyRepo) Create(ctx context.Context, policy *models.Policy) error {
35 | err := r.repo.Create(ctx, policy)
36 | if err != nil {
37 | provider.Logger(ctx).WithError(err).Errorw("policy create failed", map[string]interface{}{"id": policy.ID})
38 | return err
39 | }
40 |
41 | provider.Logger(ctx).Infow("policy created", map[string]interface{}{"id": policy.ID})
42 |
43 | return nil
44 | }
45 |
46 | func (r *PolicyRepo) Update(ctx context.Context, policy *models.Policy) error {
47 | err := r.repo.Update(ctx, policy)
48 | if err != nil {
49 | if err == spine.NoRowAffected {
50 | provider.Logger(ctx).Debugw(
51 | "no row affected by policy update",
52 | map[string]interface{}{"policy_id": policy.ID},
53 | )
54 | return nil
55 | }
56 | provider.Logger(ctx).WithError(err).Errorw(
57 | "policy update failed",
58 | map[string]interface{}{"policy_id": policy.ID})
59 | return err
60 | }
61 |
62 | provider.Logger(ctx).Infow("policy updated", map[string]interface{}{"id": policy.ID})
63 |
64 | return nil
65 | }
66 |
67 | func (r *PolicyRepo) Find(ctx context.Context, id string) (*models.Policy, error) {
68 | policy := models.Policy{}
69 |
70 | err := r.repo.FindByID(ctx, &policy, id)
71 | if err != nil {
72 | return nil, err
73 | }
74 |
75 | return &policy, nil
76 | }
77 |
78 | func (r *PolicyRepo) FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Policy, error) {
79 | var policies []models.Policy
80 |
81 | err := r.repo.FindMany(ctx, &policies, conditions)
82 | if err != nil {
83 | return nil, err
84 | }
85 |
86 | return policies, nil
87 | }
88 |
89 | // func (r *PolicyRepo) GetAll(ctx context.Context) ([]models.Policy, error) {
90 | // var policies []models.Policy
91 |
92 | // err := r.repo.FindMany(ctx, &policies, make(map[string]interface{}))
93 | // if err != nil {
94 | // return nil, err
95 | // }
96 |
97 | // return &policies, nil
98 | // }
99 |
100 | // func (r *PolicyRepo) GetAllActive(ctx context.Context) ([]models.Policy, error) {
101 | // var policies []models.Policy
102 |
103 | // err := r.repo.FindMany(ctx, &policies, map[string]interface{}{"is_enabled": true})
104 | // if err != nil {
105 | // return nil, err
106 | // }
107 |
108 | // return &policies, nil
109 | // }
110 |
111 | func (r *PolicyRepo) Enable(ctx context.Context, id string) error {
112 | provider.Logger(ctx).Infow("policy activation triggered", map[string]interface{}{"policy_id": id})
113 |
114 | policy, err := r.Find(ctx, id)
115 | if err != nil {
116 | provider.Logger(ctx).Error("policy activation failed: " + err.Error())
117 | return err
118 | }
119 |
120 | if *policy.IsEnabled {
121 | provider.Logger(ctx).Error("policy activation failed. Already active")
122 | return errors.New("Already active")
123 | }
124 |
125 | *policy.IsEnabled = true
126 |
127 | if err := r.repo.Update(ctx, policy); err != nil {
128 | return err
129 | }
130 |
131 | return nil
132 | }
133 |
134 | func (r *PolicyRepo) Disable(ctx context.Context, id string) error {
135 | provider.Logger(ctx).Infow("policy activation triggered", map[string]interface{}{"policy_id": id})
136 |
137 | policy, err := r.Find(ctx, id)
138 | if err != nil {
139 | provider.Logger(ctx).Error("policy activation failed: " + err.Error())
140 | return err
141 | }
142 |
143 | if !*policy.IsEnabled {
144 | provider.Logger(ctx).Error("policy activation failed. Already active")
145 | return errors.New("Already active")
146 | }
147 |
148 | *policy.IsEnabled = false
149 |
150 | if err := r.repo.Update(ctx, policy); err != nil {
151 | return err
152 | }
153 |
154 | return nil
155 | }
156 |
157 | func (r *PolicyRepo) Delete(ctx context.Context, id string) error {
158 | provider.Logger(ctx).Infow("policy delete request", map[string]interface{}{"policy_id": id})
159 |
160 | policy, err := r.Find(ctx, id)
161 | if err != nil {
162 | provider.Logger(ctx).Error("policy delete failed: " + err.Error())
163 | return err
164 | }
165 |
166 | // _ = policy
167 |
168 | err = r.repo.Delete(ctx, policy)
169 | if err != nil {
170 | return err
171 | }
172 |
173 | return nil
174 | }
175 |
--------------------------------------------------------------------------------
/internal/gatewayserver/repo/query.go:
--------------------------------------------------------------------------------
1 | package repo
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/razorpay/trino-gateway/internal/gatewayserver/database/dbRepo"
7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models"
8 | "github.com/razorpay/trino-gateway/internal/provider"
9 | "github.com/razorpay/trino-gateway/pkg/spine"
10 | )
11 |
12 | type IQueryRepo interface {
13 | Create(ctx context.Context, query *models.Query) error
14 | Update(ctx context.Context, query *models.Query) error
15 | Find(ctx context.Context, id string) (*models.Query, error)
16 | FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Query, error)
17 | // Find(ctx context.Context, id string) (*Query, error)
18 | // FindAll(ctx context.Context) ([]Query, error)
19 | }
20 |
21 | type QueryRepo struct {
22 | repo dbRepo.IDbRepo
23 | }
24 |
25 | // NewCore returns a new instance of *Core
26 | func NewQueryRepo(repo dbRepo.IDbRepo) *QueryRepo {
27 | return &QueryRepo{repo: repo}
28 | }
29 |
30 | func (r *QueryRepo) Create(ctx context.Context, query *models.Query) error {
31 | err := r.repo.Create(ctx, query)
32 | if err != nil {
33 | provider.Logger(ctx).WithError(err).Errorw("query create failed", map[string]interface{}{"query_id": query.ID})
34 | return err
35 | }
36 |
37 | provider.Logger(ctx).Infow("query created", map[string]interface{}{"query_id": query.ID})
38 |
39 | return nil
40 | }
41 |
42 | func (r *QueryRepo) Update(ctx context.Context, query *models.Query) error {
43 | err := r.repo.Update(ctx, query)
44 | if err != nil {
45 | if err == spine.NoRowAffected {
46 | provider.Logger(ctx).Debugw(
47 | "no row affected by query update",
48 | map[string]interface{}{"query_id": query.ID},
49 | )
50 | return nil
51 | }
52 | provider.Logger(ctx).WithError(err).Errorw(
53 | "query update failed",
54 | map[string]interface{}{"query_id": query.ID})
55 | return err
56 | }
57 |
58 | provider.Logger(ctx).Infow("query updated", map[string]interface{}{"query_id": query.ID})
59 |
60 | return nil
61 | }
62 |
63 | func (r *QueryRepo) Find(ctx context.Context, id string) (*models.Query, error) {
64 | query := models.Query{}
65 |
66 | err := r.repo.FindByID(ctx, &query, id)
67 | if err != nil {
68 | return nil, err
69 | }
70 |
71 | return &query, nil
72 | }
73 |
74 | func (r *QueryRepo) FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Query, error) {
75 | var queries []models.Query
76 |
77 | err := r.repo.FindMany(ctx, &queries, conditions)
78 | if err != nil {
79 | return nil, err
80 | }
81 |
82 | return queries, nil
83 | }
84 |
--------------------------------------------------------------------------------
/internal/monitor/metric.go:
--------------------------------------------------------------------------------
1 | package monitor
2 |
3 | import (
4 | "github.com/prometheus/client_golang/prometheus"
5 | "github.com/prometheus/client_golang/prometheus/promauto"
6 | "github.com/razorpay/trino-gateway/internal/boot"
7 | )
8 |
9 | type Metrics struct {
10 | executionsTotal *prometheus.CounterVec
11 | executionlastRunAt *prometheus.GaugeVec
12 | executionDurations *prometheus.HistogramVec
13 | backendLoad *prometheus.GaugeVec
14 | }
15 |
16 | var metrics *Metrics
17 |
18 | func initMetrics() {
19 | env := boot.Config.App.Env
20 | metrics = &Metrics{}
21 | metrics.executionsTotal = promauto.NewCounterVec(
22 | prometheus.CounterOpts{
23 | Name: "trino_gateway_monitor_executions_total",
24 | Help: "Number of executions triggered for monitor task.",
25 | },
26 | []string{"env"},
27 | ).MustCurryWith(prometheus.Labels{"env": env})
28 |
29 | metrics.executionlastRunAt = promauto.NewGaugeVec(
30 | prometheus.GaugeOpts{
31 | Name: "trino_gateway_monitor_execution_last_run_at",
32 | Help: "Monitor task last run epoch ts.",
33 | },
34 | []string{"env"},
35 | ).MustCurryWith(prometheus.Labels{"env": env})
36 |
37 | metrics.executionDurations = promauto.NewHistogramVec(
38 | prometheus.HistogramOpts{
39 | Name: "trino_gateway_monitor_execution_seconds_histogram",
40 | Help: "Monitor task execution time distributions histogram.",
41 | Buckets: []float64{5, 15, 30, 60, 90, 120, 150, 180, 210, 240},
42 | },
43 | []string{"env"},
44 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec)
45 |
46 | metrics.backendLoad = promauto.NewGaugeVec(
47 | prometheus.GaugeOpts{
48 | Name: "trino_gateway_monitor_backend_load",
49 | Help: "Backend Load computed by last run of monitor task.",
50 | },
51 | []string{"env", "backend"},
52 | ).MustCurryWith(prometheus.Labels{"env": env})
53 | }
54 |
--------------------------------------------------------------------------------
/internal/monitor/monitor.go:
--------------------------------------------------------------------------------
1 | package monitor
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "sync"
7 | "time"
8 |
9 | "github.com/go-co-op/gocron"
10 | "github.com/razorpay/trino-gateway/internal/provider"
11 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
12 | )
13 |
14 | type Monitor struct {
15 | core ICore
16 | }
17 |
18 | func init() {
19 | initMetrics()
20 | }
21 |
22 | func NewMonitor(core ICore) *Monitor {
23 | return &Monitor{
24 | core: core,
25 | }
26 | }
27 |
28 | func (m *Monitor) Schedule(ctx *context.Context, interval string) error {
29 | s := gocron.NewScheduler(time.UTC)
30 | j, err := s.Every(interval).Do(m.Execute, ctx)
31 | if err != nil {
32 | return err
33 | }
34 | s.SetMaxConcurrentJobs(1, gocron.RescheduleMode)
35 | s.StartImmediately().StartAsync()
36 |
37 | provider.Logger(*ctx).Infow("Scheduled Monitoring Job", map[string]interface{}{
38 | "job": fmt.Sprintf("%v", j),
39 | "nextRunUTC": j.NextRun().Local().UTC(),
40 | })
41 |
42 | return nil
43 | }
44 |
45 | func (m *Monitor) Execute(ctx *context.Context) {
46 | provider.Logger(*ctx).Info("Executing monitoring task")
47 |
48 | metrics.executionsTotal.
49 | WithLabelValues().Inc()
50 |
51 | defer func(st time.Time) {
52 | duration := float64(time.Since(st).Seconds())
53 | metrics.executionDurations.
54 | WithLabelValues().Observe(duration)
55 |
56 | metrics.executionlastRunAt.
57 | WithLabelValues().SetToCurrentTime()
58 | }(time.Now())
59 |
60 | provider.Logger(*ctx).Info("Evaluating new state for backends")
61 | newStates, err := m.core.EvaluateBackendNewState(ctx)
62 | if err != nil {
63 | provider.Logger(*ctx).WithError(err).Error("Error evaluating new states for backends")
64 | return
65 | }
66 |
67 | if len(newStates.Healthy) == 0 {
68 | provider.Logger(*ctx).Error("No Backends are in Healthy state.")
69 | }
70 |
71 | provider.Logger(*ctx).Debug("Marking healthy/unhealthy backends as per evaluated states")
72 | var wg sync.WaitGroup
73 | for _, b := range newStates.Unhealthy {
74 | wg.Add(1)
75 | go func(x *gatewayv1.Backend) {
76 | defer wg.Done()
77 | err = m.core.MarkUnhealthyBackend(ctx, x)
78 | if err != nil {
79 | provider.Logger(*ctx).WithError(err).Errorw(
80 | "Failure marking backend as Unhealthy",
81 | map[string]interface{}{"backend": x})
82 | }
83 | }(b)
84 | }
85 |
86 | for _, b := range newStates.Healthy {
87 | wg.Add(1)
88 | go func(x *gatewayv1.Backend) {
89 | defer wg.Done()
90 | err = m.core.MarkHealthyBackend(ctx, x)
91 | if err != nil {
92 | provider.Logger(*ctx).WithError(err).Errorw(
93 | "Failure marking backend as Healthy",
94 | map[string]interface{}{"backend": x})
95 | }
96 | }(b)
97 | }
98 |
99 | // Wait for all backend health updates to complete.
100 | wg.Wait()
101 |
102 | provider.Logger(*ctx).Info("Finished executing monitoring task")
103 | }
104 |
--------------------------------------------------------------------------------
/internal/provider/logger.go:
--------------------------------------------------------------------------------
1 | package provider
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/razorpay/trino-gateway/pkg/logger"
7 | )
8 |
9 | // Logger will provider the logger instance
10 | func Logger(ctx context.Context) *logger.Entry {
11 | ctxLogger, err := logger.Ctx(ctx)
12 | if err == nil {
13 | return ctxLogger
14 | }
15 |
16 | panic(err.Error())
17 | }
18 |
--------------------------------------------------------------------------------
/internal/router/auth.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "io/ioutil"
9 | "net/http"
10 | "time"
11 |
12 | "github.com/razorpay/trino-gateway/internal/boot"
13 | "github.com/razorpay/trino-gateway/internal/provider"
14 | "github.com/razorpay/trino-gateway/internal/router/trinoheaders"
15 | "github.com/razorpay/trino-gateway/internal/utils"
16 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
17 | )
18 |
// IAuthService verifies a username/password pair, reporting whether it is
// valid or an error if verification could not be performed.
type IAuthService interface {
	Authenticate(ctx *context.Context, username, password string) (bool, error)
}

// AuthService implements IAuthService. It checks credentials against an
// external validation provider and caches successful logins in memory.
type AuthService struct {
	// Endpoint that receives the JSON validation POST.
	ValidationProviderURL string
	// Sent to the provider in the X-Auth-Token request header.
	ValidationProviderToken string
}
27 |
28 | /*
29 | Calls an Auth Token Validator Service with following api contract:
30 | With Params-
31 |
32 | {
33 | "email": "abc@xyz.com",
34 | "token": "token123"
35 | }
36 |
37 | api returns-
38 | If Authenticated - {"ok": true}
39 | If not authenticated- {"ok": false}
40 |
41 | @returns-
42 | boolean{True or False},error_message
43 | */
44 | func (s *AuthService) ValidateFromValidationProvider(ctx *context.Context, username string, password string) (bool, error) {
45 | payload := struct {
46 | Username string `json:"email"`
47 | Token string `json:"token"`
48 | }{
49 | Username: username,
50 | Token: password,
51 | }
52 |
53 | payloadBytes, _ := json.Marshal(payload)
54 | req, _ := http.NewRequest("POST", s.ValidationProviderURL, bytes.NewReader(payloadBytes))
55 | req.Header.Set("X-Auth-Token", s.ValidationProviderToken)
56 | req.Header.Set("Content-Type", "application/json")
57 |
58 | client := &http.Client{}
59 | resp, err := client.Do(req)
60 | if err != nil {
61 | return false, err
62 | }
63 | defer resp.Body.Close()
64 |
65 | respBody, _ := ioutil.ReadAll(resp.Body)
66 |
67 | var data struct {
68 | OK bool `json:"ok"`
69 | }
70 | jsonParseError := json.Unmarshal([]byte(respBody), &data)
71 | if jsonParseError != nil {
72 | return false, jsonParseError
73 | }
74 |
75 | return data.OK, nil
76 | }
77 |
78 | func (s *AuthService) Authenticate(ctx *context.Context, username string, password string) (bool, error) {
79 | authCache := s.GetInMemoryAuthCache(ctx)
80 |
81 | if entry, exists := authCache.Get(username); exists && entry == password {
82 | authCache.Update(username, password)
83 | return true, nil
84 | }
85 |
86 | isValid, err := s.ValidateFromValidationProvider(ctx, username, password)
87 | if err != nil {
88 | return false, err
89 | }
90 |
91 | if isValid {
92 | authCache.Update(username, password)
93 | }
94 |
95 | return isValid, nil
96 | }
97 |
98 | func (s *AuthService) GetInMemoryAuthCache(ctx *context.Context) utils.ISimpleCache {
99 | ctxKeyName := "routerAuthCache"
100 | authCache, ok := (*ctx).Value(ctxKeyName).(*utils.InMemorySimpleCache)
101 |
102 | if !ok {
103 |
104 | expiryInterval, _ := time.ParseDuration(boot.Config.Auth.Router.DelegatedAuth.CacheTTLMinutes)
105 |
106 | authCache = &utils.InMemorySimpleCache{
107 | Cache: make(map[string]struct {
108 | Timestamp time.Time
109 | Value string
110 | }),
111 | ExpiryInterval: expiryInterval,
112 | }
113 | *ctx = context.WithValue(*ctx, ctxKeyName, authCache)
114 | }
115 | return authCache
116 | }
117 |
118 | func (r *RouterServer) isAuthDelegated(ctx *context.Context) (bool, error) {
119 | res, err := r.gatewayApiClient.Policy.EvaluateAuthDelegationForClient(*ctx, &gatewayv1.EvaluateAuthDelegationRequest{IncomingPort: int32(r.port)})
120 |
121 | if err != nil {
122 | provider.Logger(*ctx).WithError(err).Errorw(
123 | fmt.Sprint(LOG_TAG, "Failed to evaluate auth delegation policy. Assuming delegation is disabled."),
124 | map[string]interface{}{
125 | "port": r.port,
126 | })
127 | return false, err
128 | }
129 | return res.GetIsAuthDelegated(), nil
130 | }
131 |
132 | func (r *RouterServer) AuthHandler(ctx *context.Context, h http.Handler) http.Handler {
133 | return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
134 | if isAuth, _ := r.isAuthDelegated(ctx); isAuth {
135 | // TODO: Refactor auth type handling to a dedicated type
136 |
137 | // BasicAuth
138 | username, password, isBasicAuth := req.BasicAuth()
139 |
140 | // CustomAuth
141 | if !isBasicAuth {
142 | provider.Logger(*ctx).Debug("Custom Auth type")
143 | username = trinoheaders.Get(trinoheaders.User, req)
144 | password = trinoheaders.Get(trinoheaders.Password, req)
145 | } else {
146 | if u := trinoheaders.Get(trinoheaders.User, req); u != username {
147 | errorMsg := fmt.Sprintf("Username from basicauth - %s does not match with User principal - %s", username, u)
148 | provider.Logger(*ctx).Debug(errorMsg)
149 | http.Error(w, errorMsg, http.StatusUnauthorized)
150 | }
151 |
152 | // Remove auth details from request
153 | req.Header.Del("Authorization")
154 | }
155 |
156 | // NoAuth
157 | isNoAuth := password == ""
158 | if isNoAuth {
159 | provider.Logger(*ctx).Debug("No Auth type detected")
160 | errorMsg := fmt.Sprintf("Password required")
161 | http.Error(w, errorMsg, http.StatusUnauthorized)
162 | return
163 | }
164 |
165 | isAuthenticated, err := r.authService.Authenticate(ctx, username, password)
166 |
167 | if err != nil {
168 | errorMsg := fmt.Sprintf("Unable to Authenticate users. Getting error - %s", err)
169 | provider.Logger(*ctx).Error(errorMsg)
170 | http.Error(w, "Unable to Authenticate the user", http.StatusNotFound)
171 | return
172 | }
173 | if !isAuthenticated {
174 | provider.Logger(*ctx).Debug(fmt.Sprintf("User - %s not authenticated", username))
175 | http.Error(w, "User not authenticated", http.StatusUnauthorized)
176 | return
177 | }
178 | h.ServeHTTP(w, req)
179 | } else {
180 | h.ServeHTTP(w, req)
181 | }
182 | })
183 | }
184 |
--------------------------------------------------------------------------------
/internal/router/auth_test.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/razorpay/trino-gateway/pkg/logger"
8 | "github.com/stretchr/testify/suite"
9 | )
10 |
11 | // Define the suite, and absorb the built-in basic suite
12 | // functionality from testify - including a T() method which
13 | // returns the current testing context
14 | type AuthSuite struct {
15 | suite.Suite
16 | authService *AuthService
17 | ctx *context.Context
18 | }
19 |
20 | func (suite *AuthSuite) SetupTest() {
21 | lgrConfig := logger.Config{
22 | LogLevel: logger.Warn,
23 | }
24 |
25 | l, err := logger.NewLogger(lgrConfig)
26 | if err != nil {
27 | panic("failed to initialize logger")
28 | }
29 |
30 | c := context.WithValue(context.Background(), logger.LoggerCtxKey, l)
31 |
32 | suite.ctx = &c
33 | suite.authService = &AuthService{}
34 | }
35 |
36 | func (suite *AuthSuite) Test_GetInMemoryAuthCache_Persistance() {
37 | key := "testKey"
38 | value := "testValue"
39 |
40 | authCache := suite.authService.GetInMemoryAuthCache(suite.ctx)
41 | authCache.Update(key, value)
42 |
43 | authCacheInstance2 := suite.authService.GetInMemoryAuthCache(suite.ctx)
44 | entry, exists := authCacheInstance2.Get(key)
45 |
46 | suite.Truef(exists, "Second cache instance doesn't have same key")
47 | if exists {
48 | suite.Equalf(value, entry, "Second Cache instance value doesn't match.")
49 | }
50 |
51 | }
52 | func TestAuthSuite(t *testing.T) {
53 | suite.Run(t, new(AuthSuite))
54 | }
55 |
--------------------------------------------------------------------------------
/internal/router/metric.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "github.com/prometheus/client_golang/prometheus"
5 | "github.com/prometheus/client_golang/prometheus/promauto"
6 | "github.com/razorpay/trino-gateway/internal/boot"
7 | )
8 |
9 | type Metrics struct {
10 | requestsReceivedTotal *prometheus.CounterVec
11 | requestsRoutedTotal *prometheus.CounterVec
12 | requestPreRoutingDelays *prometheus.HistogramVec
13 | requestPostRoutingDelays *prometheus.HistogramVec
14 | responsesSentTotal *prometheus.CounterVec
15 | responseDurations *prometheus.HistogramVec
16 | }
17 |
18 | var metrics *Metrics
19 |
20 | func initMetrics() {
21 | env := boot.Config.App.Env
22 | metrics = &Metrics{}
23 | metrics.requestsReceivedTotal = promauto.NewCounterVec(
24 | prometheus.CounterOpts{
25 | Name: "trino_gateway_router_http_requests_total",
26 | Help: "Number of HTTP requests received from clients.",
27 | },
28 | []string{"env", "method", "port"},
29 | ).MustCurryWith(prometheus.Labels{"env": env})
30 |
31 | metrics.requestsRoutedTotal = promauto.NewCounterVec(
32 | prometheus.CounterOpts{
33 | Name: "trino_gateway_router_http_requests_routed_total",
34 | Help: "Number of HTTP requests routed to a trino server.",
35 | },
36 | []string{"env", "method", "port", "group", "backend"},
37 | ).MustCurryWith(prometheus.Labels{"env": env})
38 |
39 | metrics.requestPreRoutingDelays = promauto.NewHistogramVec(
40 | prometheus.HistogramOpts{
41 | Name: "trino_gateway_router_http_pre_routing_delay_ms_histogram",
42 | Help: "Delay in routing client request to a Trino server, latency distributions histogram.",
43 | Buckets: []float64{5, 10, 15, 20, 30, 40, 60, 100, 150, 500},
44 | },
45 | []string{"env", "method"},
46 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec)
47 |
48 | metrics.requestPostRoutingDelays = promauto.NewHistogramVec(
49 | prometheus.HistogramOpts{
50 | Name: "trino_gateway_router_http_post_routing_delay_ms_histogram",
51 | Help: "Delay in sending response to client after receiving response from Trino server, latency distributions histogram.",
52 | Buckets: []float64{2, 5, 10, 15, 20, 25, 30, 40, 50, 100, 500},
53 | },
54 | []string{"env", "method", "code"},
55 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec)
56 |
57 | metrics.responsesSentTotal = promauto.NewCounterVec(
58 | prometheus.CounterOpts{
59 | Name: "trino_gateway_router_http_responses_total",
60 | Help: "Number of HTTP responses sent back to client.",
61 | },
62 | []string{"env", "method", "code"},
63 | ).MustCurryWith(prometheus.Labels{"env": env})
64 |
65 | metrics.responseDurations = promauto.NewHistogramVec(
66 | prometheus.HistogramOpts{
67 | Name: "trino_gateway_router_http_durations_ms_histogram",
68 | Help: "Router HTTP latency distributions histogram for responses sent to clients.",
69 | Buckets: []float64{20, 40, 60, 90, 120, 150, 200, 250, 300, 500},
70 | },
71 | []string{"env", "method", "code"},
72 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec)
73 | }
74 |
--------------------------------------------------------------------------------
/internal/router/request_test.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/razorpay/trino-gateway/pkg/logger"
8 | "github.com/stretchr/testify/suite"
9 | )
10 |
11 | type HelpersSuite struct {
12 | suite.Suite
13 | ctx *context.Context
14 | }
15 |
16 | func (suite *HelpersSuite) SetupTest() {
17 | lgrConfig := logger.Config{
18 | LogLevel: logger.Warn,
19 | }
20 |
21 | l, err := logger.NewLogger(lgrConfig)
22 | if err != nil {
23 | panic("failed to initialize logger")
24 | }
25 |
26 | c := context.WithValue(context.Background(), logger.LoggerCtxKey, l)
27 |
28 | suite.ctx = &c
29 | }
30 |
31 | func (suite *HelpersSuite) Test_extractQueryId() {
32 | }
33 |
34 | func (suite *HelpersSuite) Test_isValidRequest() {
35 | }
36 |
37 | func (suite *HelpersSuite) Test_constructQueryFromReq() {
38 | }
39 |
40 | func TestSuite(t *testing.T) {
41 | suite.Run(t, new(HelpersSuite))
42 | }
43 |
--------------------------------------------------------------------------------
/internal/router/request_type.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "fmt"
5 |
6 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
7 | )
8 |
// ClientRequest is the closed set of request kinds handled by the router.
// The unexported isClientRequest marker means only types in this package can
// implement it; Validate reports whether a request is well-formed.
type ClientRequest interface {
	Validate() error
	isClientRequest()
}
13 |
14 | type UiRequest struct {
15 | ClientRequest
16 | queryId string
17 | }
18 |
19 | func (UiRequest) isClientRequest() {}
20 |
21 | func (r UiRequest) Validate() error {
22 | if r.queryId == "" {
23 | tag := "ui"
24 | return fmt.Errorf("%s: %s", tag, "Missing query id")
25 | }
26 | return nil
27 | }
28 |
29 | type QueryApiRequest struct {
30 | ClientRequest
31 | headerConnectionProperties string
32 | headerClientTags string
33 | incomingPort int32
34 | transactionId string
35 | clientHost string
36 | Query *gatewayv1.Query
37 | }
38 |
39 | func (QueryApiRequest) isClientRequest() {}
40 | func (r QueryApiRequest) Validate() error {
41 | tag := "query api"
42 | if r.Query.GetUsername() == "" {
43 | return fmt.Errorf("%s: %s", tag, "Missing Trino Username header")
44 | }
45 | if r.Query.GetId() == "" {
46 | return fmt.Errorf("%s: %s", tag, "Missing Query Id")
47 | }
48 |
49 | // TODO: remove it once transaction support is added
50 | // Looker's Presto client sends `X-Presto-Transaction-Id: NONE`
51 | // whereas trino client doesnt send it if its not set
52 | if !(r.transactionId == "" || r.transactionId == "NONE") {
53 | return fmt.Errorf("%s: %s", tag, "Transactions are not supported in gateway.")
54 | }
55 | return nil
56 | }
57 |
58 | type QueryRequest struct {
59 | ClientRequest
60 | headerConnectionProperties string
61 | headerClientTags string
62 | incomingPort int32
63 | transactionId string
64 | clientHost string
65 | Query *gatewayv1.Query
66 | }
67 |
68 | func (QueryRequest) isClientRequest() {}
69 | func (r QueryRequest) Validate() error {
70 | tag := "query submission"
71 | if r.Query.GetUsername() == "" {
72 | return fmt.Errorf("%s: %s", tag, "Missing Trino Username header")
73 | }
74 | if r.Query.GetText() == "" {
75 | return fmt.Errorf("%s: %s", tag, "Missing Query text")
76 | }
77 |
78 | // TODO: remove it once transaction support is added
79 | // Looker's Presto client sends `X-Presto-Transaction-Id: NONE`
80 | // whereas trino client doesnt send it if its not set
81 | if !(r.transactionId == "" || r.transactionId == "NONE") {
82 | return fmt.Errorf("%s: %s", tag, "Transactions are not supported in gateway.")
83 | }
84 | return nil
85 | }
86 |
87 | type ApiRequest struct {
88 | ClientRequest
89 | }
90 |
91 | func (ApiRequest) isClientRequest() {}
92 | func (r ApiRequest) Validate() error {
93 | return nil
94 | }
95 |
--------------------------------------------------------------------------------
/internal/router/response.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "net/http"
8 | "regexp"
9 | "time"
10 |
11 | "github.com/razorpay/trino-gateway/internal/provider"
12 | "github.com/razorpay/trino-gateway/internal/utils"
13 | )
14 |
15 | func (r *RouterServer) handleRedirect(ctx *context.Context, resp *http.Response) error {
16 | // This needs more testing with looker
17 |
18 | // Handle Redirects
19 | // TODO: Clean it up
20 | // TODO - validate its working for all use cases
21 | regex := regexp.MustCompile(`\w+\:\/\/[^\/]*(.*)`)
22 | if resp.Header.Get("Location") != ("") {
23 | oldLoc := resp.Header.Get("Location")
24 | newLoc := fmt.Sprint("http://", r.routerHostname, regex.ReplaceAllString(oldLoc, "$1"))
25 | resp.Header.Set("Location", newLoc)
26 | }
27 | return nil
28 | }
29 |
30 | func (r *RouterServer) ProcessResponse(
31 | ctx *context.Context,
32 | resp *http.Response,
33 | cReq ClientRequest,
34 | ) error {
35 | switch stCode := resp.StatusCode; true {
36 | case stCode >= 200 && stCode < 300:
37 | // TODO - fix redirect
38 | _ = r.handleRedirect(ctx, resp)
39 | case stCode >= 300 && stCode < 400:
40 | // http3xx -> server sent redirection, gateway doesn't need to modify anything here
41 | // Assuming Clients can directly connect to redirected Uri
42 | default:
43 | provider.Logger(*ctx).Errorw(
44 | fmt.Sprint(LOG_TAG, "Routing unsuccessful"),
45 | map[string]interface{}{
46 | "serverResponse": utils.StringifyHttpRequestOrResponse(ctx, resp),
47 | })
48 | return nil
49 | }
50 |
51 | provider.Logger(*ctx).Debug(LOG_TAG + "Routing successful")
52 |
53 | switch nt := cReq.(type) {
54 | case *ApiRequest:
55 | return nil
56 | case *UiRequest:
57 | return nil
58 | case *QueryRequest:
59 | req := nt.Query
60 | body, err := utils.ParseHttpPayloadBody(ctx, &resp.Body, utils.GetHttpBodyEncoding(ctx, resp))
61 | if err != nil {
62 | provider.Logger(*ctx).WithError(err).Error(fmt.Sprint(LOG_TAG, "unable to parse body of server response"))
63 | }
64 |
65 | go func() {
66 | req.Id = extractQueryIdFromServerResponse(ctx, body)
67 | req.SubmittedAt = time.Now().Unix()
68 |
69 | _, err = r.gatewayApiClient.Query.CreateOrUpdateQuery(*ctx, req)
70 | if err != nil {
71 | provider.Logger(
72 | *ctx).WithError(err).Errorw(
73 | fmt.Sprint(LOG_TAG, "Unable to save query"),
74 | map[string]interface{}{
75 | "query_id": req.Id,
76 | })
77 | }
78 | }()
79 |
80 | provider.Logger(*ctx).Debugw("Server Response Processed", map[string]interface{}{
81 | "resp": utils.StringifyHttpRequestOrResponse(ctx, resp),
82 | })
83 |
84 | return nil
85 | default:
86 | return nil
87 | }
88 | }
89 |
90 | func extractQueryIdFromServerResponse(ctx *context.Context, body string) string {
91 | provider.Logger(*ctx).Debugw(fmt.Sprint(LOG_TAG, "extracting queryId from server response"),
92 | map[string]interface{}{
93 | "body": body,
94 | })
95 | var resp struct{ Id string }
96 | json.Unmarshal([]byte(body), &resp)
97 | return resp.Id
98 | }
99 |
--------------------------------------------------------------------------------
/internal/router/trinoheaders/trino.go:
--------------------------------------------------------------------------------
1 | package trinoheaders
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | )
7 |
8 | // https://github.com/trinodb/trino/blob/master/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java
9 | // Connection properties is not part of trino client protocol but is sent from some jdbc clients
10 | const (
11 | PreparedStatement = "Prepared-Statement"
12 | User = "User"
13 | ClientTags = "Client-Tags"
14 | ConnectionProperties = "Connection-Properties"
15 | TransactionId = "Transaction-Id"
16 | Password = "Password"
17 | Source = "Source"
18 | )
19 |
20 | var allowedPrefixes = [...]string{"Presto", "Trino"}
21 |
22 | func Get(key string, req *http.Request) string {
23 | for _, h := range allowedPrefixes {
24 | s := fmt.Sprintf("X-%s-%s", h, key)
25 |
26 | if val := req.Header.Get(s); val != "" {
27 | return val
28 | }
29 | }
30 | return ""
31 | }
32 |
--------------------------------------------------------------------------------
/internal/router/trinoheaders/trino_test.go:
--------------------------------------------------------------------------------
1 | package trinoheaders
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func Test_Get(t *testing.T) {
11 | trinoHttpReq := &http.Request{
12 | Header: map[string][]string{
13 | "X-Trino-User": {"user"},
14 | "X-Trino-Connection-Properties": {"connProps"},
15 | },
16 | }
17 | assert.Equal(t, Get("User", trinoHttpReq), "user")
18 | assert.Equal(t, Get("Connection-Properties", trinoHttpReq), "connProps")
19 |
20 | prestoHttpReq := &http.Request{
21 | Header: map[string][]string{
22 | "X-Presto-User": {"user"},
23 | "X-Presto-Connection-Properties": {"connProps"},
24 | },
25 | }
26 | assert.Equal(t, Get("User", prestoHttpReq), "user")
27 | assert.Equal(t, Get("Connection-Properties", prestoHttpReq), "connProps")
28 | }
29 |
--------------------------------------------------------------------------------
/internal/utils/utils.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "bytes"
5 | "compress/gzip"
6 | "context"
7 | "io"
8 | "io/ioutil"
9 | "net/http"
10 | "net/http/httputil"
11 | "sync"
12 | "time"
13 |
14 | "github.com/razorpay/trino-gateway/internal/provider"
15 | "github.com/robfig/cron/v3"
16 | )
17 |
18 | /*
19 | Checks whether provided time object is in 1 minute window of cron expression
20 | */
21 | func IsTimeInCron(ctx *context.Context, t time.Time, sched string) (bool, error) {
22 | s, err := cron.ParseStandard(sched)
23 | if err != nil {
24 | return false, err
25 | }
26 | nextRun := s.Next(t)
27 |
28 | provider.Logger(*ctx).Debugw(
29 | "Evaluated next valid ts from cron expression",
30 | map[string]interface{}{
31 | "providedTime": t,
32 | "nextRun": nextRun,
33 | },
34 | )
35 |
36 | return nextRun.Sub(t).Minutes() <= 1, nil
37 | }
38 |
// SliceContains reports whether element occurs in collection (nil or empty
// collections contain nothing).
func SliceContains[T comparable](collection []T, element T) bool {
	for i := range collection {
		if collection[i] == element {
			return true
		}
	}
	return false
}
48 |
49 | // Finds intersection of 2 slices
50 | func SimpleSliceIntersection[T comparable](list1 []T, list2 []T) []T {
51 | result := []T{}
52 | seen := map[T]struct{}{}
53 |
54 | for _, elem := range list1 {
55 | seen[elem] = struct{}{}
56 | }
57 |
58 | for _, elem := range list2 {
59 | if _, ok := seen[elem]; ok {
60 | result = append(result, elem)
61 | }
62 | }
63 |
64 | return result
65 | }
66 |
67 | func GetHttpBodyEncoding[T *http.Request | *http.Response](ctx *context.Context, r T) string {
68 | enc := ""
69 | headerKey := "Content-Encoding"
70 | switch v := any(r).(type) {
71 | case *http.Request:
72 | enc = v.Header.Get(headerKey)
73 | case *http.Response:
74 | enc = v.Header.Get(headerKey)
75 | }
76 | return enc
77 | }
78 |
79 | func StringifyHttpRequestOrResponse[T *http.Request | *http.Response](ctx *context.Context, r T) string {
80 | canDumpBody := GetHttpBodyEncoding(ctx, r) == ""
81 | if !canDumpBody {
82 | provider.Logger(*ctx).Debug(
83 | "Encoded body in http payload, assuming binary data and skipping dump of body")
84 | }
85 | var res []byte
86 | var err error
87 | switch v := any(r).(type) {
88 | case *http.Request:
89 | res, err = httputil.DumpRequest(v, canDumpBody)
90 | case *http.Response:
91 | res, err = httputil.DumpResponse(v, canDumpBody)
92 | }
93 | if err != nil {
94 | provider.Logger(*ctx).Errorw(
95 | "Unable to stringify http payload",
96 | map[string]interface{}{
97 | "error": err.Error(),
98 | })
99 | }
100 | return string(res)
101 | }
102 |
103 | func ParseHttpPayloadBody(ctx *context.Context, body *io.ReadCloser, encoding string) (string, error) {
104 | bodyBytes, err := io.ReadAll(*body)
105 | if err != nil {
106 | return "", err
107 | }
108 | // since its a ReadCloser type, the stream will be empty after its read once
109 | // ensure it is restored in original object
110 | *body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
111 |
112 | switch encoding {
113 | case "gzip":
114 | var reader io.ReadCloser
115 | reader, err = gzip.NewReader(bytes.NewReader([]byte(bodyBytes)))
116 | if err != nil {
117 | provider.Logger(
118 | *ctx).WithError(err).Error(
119 | "Unable to decompress gzip encoded response")
120 | }
121 | defer reader.Close()
122 | bb, err := io.ReadAll(reader)
123 | if err != nil {
124 | return "", err
125 | }
126 |
127 | return string(bb), nil
128 | default:
129 | return string(bodyBytes), nil
130 | }
131 | }
132 |
133 | type ISimpleCache interface {
134 | Get(key string) (string, bool)
135 | Update(key, value string)
136 | }
137 |
138 | type InMemorySimpleCache struct {
139 | Cache map[string]struct {
140 | Timestamp time.Time
141 | Value string
142 | }
143 | ExpiryInterval time.Duration
144 | mu sync.Mutex
145 | }
146 |
147 | func (authCache *InMemorySimpleCache) Get(key string) (string, bool) {
148 | authCache.mu.Lock()
149 | defer authCache.mu.Unlock()
150 |
151 | entry, found := authCache.Cache[key]
152 |
153 | if !found {
154 | return "", false
155 | }
156 |
157 | if authCache.ExpiryInterval > 0 &&
158 | time.Since(entry.Timestamp) > authCache.ExpiryInterval {
159 | // If entry is older than cachedDuration, then delete the record and return false
160 | delete(authCache.Cache, key)
161 | return "", false
162 | }
163 | return entry.Value, true
164 | }
165 |
166 | func (authCache *InMemorySimpleCache) Update(key, value string) {
167 | authCache.mu.Lock()
168 | defer authCache.mu.Unlock()
169 |
170 | authCache.Cache[key] = struct {
171 | Timestamp time.Time
172 | Value string
173 | }{
174 | Timestamp: time.Now(),
175 | Value: value,
176 | }
177 | }
178 |
--------------------------------------------------------------------------------
/internal/utils/utils_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "bytes"
5 | "compress/gzip"
6 | "context"
7 | "io"
8 | "strings"
9 | "testing"
10 | "time"
11 |
12 | "github.com/razorpay/trino-gateway/pkg/logger"
13 | "github.com/stretchr/testify/suite"
14 | )
15 |
16 | // Define the suite, and absorb the built-in basic suite
17 | // functionality from testify - including a T() method which
18 | // returns the current testing context
19 | type UtilsSuite struct {
20 | suite.Suite
21 | ctx *context.Context
22 | }
23 |
24 | func (suite *UtilsSuite) SetupTest() {
25 | lgrConfig := logger.Config{
26 | LogLevel: logger.Warn,
27 | }
28 |
29 | l, err := logger.NewLogger(lgrConfig)
30 | if err != nil {
31 | panic("failed to initialize logger")
32 | }
33 |
34 | c := context.WithValue(context.Background(), logger.LoggerCtxKey, l)
35 |
36 | suite.ctx = &c
37 | }
38 |
39 | func (suite *UtilsSuite) Test_IsTimeInCron() {
40 | // func (c *Core) isCurrentTimeInCron(ctx *context.Context, sched string) (bool, error)
41 |
42 | tst := func(t time.Time, sched string) bool {
43 | s, _ := IsTimeInCron(suite.ctx, t, sched)
44 | return s
45 | }
46 |
47 | suite.Equalf(
48 | true,
49 | tst(time.Now(), "* * * * *"),
50 | "Failure",
51 | )
52 | }
53 |
54 | func (suite *UtilsSuite) Test_getHttpBodyEncoding() {
55 | }
56 |
57 | func (suite *UtilsSuite) Test_stringifyHttpRequestOrResponse() {
58 | }
59 |
60 | func (suite *UtilsSuite) Test_parseBody() {
61 | str := "body"
62 | stringReader := strings.NewReader(str)
63 | stringReadCloser := io.NopCloser(stringReader)
64 |
65 | tst := func() string {
66 | s, _ := ParseHttpPayloadBody(suite.ctx, &stringReadCloser, "")
67 | return s
68 | }
69 |
70 | suite.Equalf(str, tst(), "Failed to extract string from body")
71 | suite.Equalf(str, tst(), "String extraction is not idempotent")
72 |
73 | var b bytes.Buffer
74 | gz := gzip.NewWriter(&b)
75 | if _, err := gz.Write([]byte(str)); err != nil {
76 | panic(err)
77 | }
78 | if err := gz.Close(); err != nil {
79 | panic(err)
80 | }
81 |
82 | strGzipped := b.String()
83 | stringReaderGzipped := strings.NewReader(strGzipped)
84 | stringReadCloserGzipped := io.NopCloser(stringReaderGzipped)
85 |
86 | tst_gzipped := func() string {
87 | s, _ := ParseHttpPayloadBody(suite.ctx, &stringReadCloserGzipped, "gzip")
88 | return s
89 | }
90 |
91 | suite.Equalf(str, tst_gzipped(), "Failed to extract string from body")
92 | suite.Equalf(str, tst_gzipped(), "String extraction is not idempotent")
93 | }
94 |
95 | func (suite *UtilsSuite) Test_InMemorySimpleCache_Get() {
96 | authCache := &InMemorySimpleCache{
97 | Cache: make(map[string]struct {
98 | Timestamp time.Time
99 | Value string
100 | }),
101 | }
102 | key := "testKey"
103 | value := "testValue"
104 | authCache.Cache[key] = struct {
105 | Timestamp time.Time
106 | Value string
107 | }{
108 | Timestamp: time.Now(),
109 | Value: value,
110 | }
111 |
112 | entry, exists := authCache.Get(key)
113 | suite.Truef(exists, "Entry not found in cache.")
114 | if exists {
115 | suite.Equalf(value, entry, "Cached value doesn't match.")
116 | }
117 | }
118 |
119 | func (suite *UtilsSuite) Test_InMemorySimpleCache_Get_InfiniteExpiry() {
120 | authCache := &InMemorySimpleCache{
121 | Cache: make(map[string]struct {
122 | Timestamp time.Time
123 | Value string
124 | }),
125 | ExpiryInterval: 0 * time.Second,
126 | }
127 | key := "testKey"
128 | value := "testValue"
129 | authCache.Cache[key] = struct {
130 | Timestamp time.Time
131 | Value string
132 | }{
133 | Timestamp: time.Now().Add(-1000 * time.Hour),
134 | Value: value,
135 | }
136 |
137 | entry, exists := authCache.Get(key)
138 | suite.Truef(exists, "Entry not found in cache.")
139 | if exists {
140 | suite.Equalf(value, entry, "Cached value doesn't match.")
141 | }
142 | }
143 |
144 | func (suite *UtilsSuite) Test_InMemorySimpleCache_Get_Expired() {
145 | expiryInterval := 2 * time.Second
146 | authCache := &InMemorySimpleCache{
147 | Cache: make(map[string]struct {
148 | Timestamp time.Time
149 | Value string
150 | }),
151 | ExpiryInterval: expiryInterval,
152 | }
153 | key := "testKey"
154 | value := "testValue"
155 | authCache.Cache[key] = struct {
156 | Timestamp time.Time
157 | Value string
158 | }{
159 | Timestamp: time.Now().Add(-1 * expiryInterval).Add(-1 * time.Second),
160 | Value: value,
161 | }
162 |
163 | _, exists := authCache.Get(key)
164 | suite.False(exists, "Entry not expired.")
165 | }
166 |
167 | func (suite *UtilsSuite) Test_InMemorySimpleCache_Update() {
168 | authCache := &InMemorySimpleCache{
169 | Cache: make(map[string]struct {
170 | Timestamp time.Time
171 | Value string
172 | }),
173 | }
174 | key := "testKey"
175 | value := "testValue"
176 | authCache.Update(key, value)
177 |
178 | entry, exists := authCache.Cache[key]
179 | suite.Truef(exists, "Entry not found in cache.")
180 | if exists {
181 | suite.Equalf(value, entry.Value, "Cached value doesn't match.")
182 | }
183 | }
184 |
185 | func TestSuite(t *testing.T) {
186 | suite.Run(t, new(UtilsSuite))
187 | }
188 |
--------------------------------------------------------------------------------
/pkg/config/config.go:
--------------------------------------------------------------------------------
1 | // Package config has specific primitives for loading application configurations.
2 | //
3 | // Primitives:
4 | // - Application should have struct for containing configuration. E.g. refer
5 | // trino-gateway/internal/config/config.go file.
6 | // - Application should have a directory holding default file and environment
7 | // specific file. E.g. refer trino-gateway/configs/* directory.
8 | //
9 | // Usage:
10 | // - E.g. NewDefaultConfig().Load("dev", &config), where config is a struct
11 | // where configuration gets unmarshalled into.
12 | package config
13 |
14 | import (
15 | "os"
16 | "path"
17 | "runtime"
18 | "strings"
19 |
20 | "github.com/spf13/viper"
21 | )
22 |
23 | // Default options for configuration loading.
24 | const (
25 | DefaultConfigType = "toml"
26 | DefaultConfigDir = "./config"
27 | DefaultConfigFileName = "default"
28 | WorkDirEnv = "WORKDIR"
29 | )
30 |
31 | // Options is config options.
32 | type Options struct {
33 | configType string
34 | configPath string
35 | defaultConfigFileName string
36 | }
37 |
38 | // Config is a wrapper over a underlying config loader implementation.
39 | type Config struct {
40 | opts Options
41 | viper *viper.Viper
42 | }
43 |
44 | // NewDefaultOptions returns default options.
45 | // DISCLAIMER: This function is a bit hacky
46 | // This function expects an env $WORKDIR to
47 | // be set and reads configs from $WORKDIR/configs.
48 | // If $WORKDIR is not set. It uses the absolute path wrt
49 | // the location of this file (config.go) to set configPath
50 | // to 2 levels up in viper (../../configs).
51 | // This function breaks if :
52 | // 1. $WORKDIR is set and configs dir not present in $WORKDIR
53 | // 2. $WORKDIR is not set and ../../configs is not present
54 | // 3. $WORKDIR is not set and runtime absolute path of configs
55 | // is different than build time path as runtime.Caller() evaluates
56 | // only at build time
57 | func NewDefaultOptions() Options {
58 | var configPath string
59 | workDir := os.Getenv(WorkDirEnv)
60 | if workDir != "" {
61 | configPath = path.Join(workDir, DefaultConfigDir)
62 | } else {
63 | _, thisFile, _, _ := runtime.Caller(1)
64 | configPath = path.Join(path.Dir(thisFile), "../../"+DefaultConfigDir)
65 | }
66 | return NewOptions(DefaultConfigType, configPath, DefaultConfigFileName)
67 | }
68 |
69 | // NewOptions returns new Options struct.
70 | func NewOptions(configType string, configPath string, defaultConfigFileName string) Options {
71 | return Options{configType, configPath, defaultConfigFileName}
72 | }
73 |
74 | // NewDefaultConfig returns new config struct with default options.
75 | func NewDefaultConfig() *Config {
76 | return NewConfig(NewDefaultOptions())
77 | }
78 |
79 | // NewConfig returns new config struct.
80 | func NewConfig(opts Options) *Config {
81 | return &Config{opts, viper.New()}
82 | }
83 |
84 | // Load reads environment specific configurations and along with the defaults
85 | // unmarshalls into config.
86 | func (c *Config) Load(env string, config interface{}) error {
87 | if err := c.loadByConfigName(c.opts.defaultConfigFileName, config); err != nil {
88 | return err
89 | }
90 | return c.loadByConfigName(env, config)
91 | }
92 |
93 | // loadByConfigName reads configuration from file and unmarshalls into config.
94 | func (c *Config) loadByConfigName(configName string, config interface{}) error {
95 | c.viper.SetEnvPrefix(strings.ToUpper("trino-gateway"))
96 | c.viper.SetConfigName(configName)
97 | c.viper.SetConfigType(c.opts.configType)
98 | c.viper.AddConfigPath(c.opts.configPath)
99 | c.viper.AutomaticEnv()
100 | c.viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
101 | if err := c.viper.ReadInConfig(); err != nil {
102 | return err
103 | }
104 | return c.viper.Unmarshal(config)
105 | }
106 |
--------------------------------------------------------------------------------
/pkg/config/config_test.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "os"
5 | "strings"
6 | "testing"
7 | "time"
8 |
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | type TestConfig struct {
13 | Title string
14 | Db TestDbConfig
15 | }
16 |
17 | type TestDbConfig struct {
18 | Dialect string
19 | Protocol string
20 | Host string
21 | Port int
22 | Username string
23 | Password string
24 | SslMode string
25 | Name string
26 | MaxOpenConnections int
27 | MaxIdleConnections int
28 | ConnectionMaxLifetime time.Duration
29 | }
30 |
31 | func TestLoadConfig(t *testing.T) {
32 | var c TestConfig
33 |
34 | key := strings.ToUpper("trino-gateway") + "_DB_PASSWORD"
35 | os.Setenv(key, "envpass")
36 | err := NewConfig(NewOptions("toml", "./testdata", "default")).Load("default", &c)
37 | assert.Nil(t, err)
38 | // Asserts that default value exists.
39 | assert.Equal(t, "mysql", c.Db.Dialect)
40 | // Asserts that application environment specific value got overridden.
41 | assert.Equal(t, 10, c.Db.MaxOpenConnections)
42 | // Asserts that environment variable was honored.
43 | assert.Equal(t, "envpass", c.Db.Password)
44 | }
45 |
--------------------------------------------------------------------------------
/pkg/config/testdata/default.toml:
--------------------------------------------------------------------------------
1 | title = "Default config file for testing."
2 |
3 | [db]
4 | dialect = "mysql"
5 | protocol = "tcp"
6 | host = "localhost"
7 | port = 3306
8 | username = "trino-gateway"
9 | password = "trino-gateway"
10 | sslMode = "require"
11 | name = "trino-gateway"
12 | maxOpenConnections = 10
13 | maxIdleConnections = 10
14 | connectionMaxLifetime = 0
15 |
--------------------------------------------------------------------------------
/pkg/fetcher/fetcher.go:
--------------------------------------------------------------------------------
1 | package fetcher
2 |
// MaxLimit is the largest page size Pagination.GetLimit will return, i.e. the
// cap on the number of rows a single FetchMultiple call requests.
const MaxLimit = 2000
4 |
5 | // Request for fetching multiple entities
6 | type FetchMultipleRequest struct {
7 | EntityName string
8 | Filter map[string]interface{}
9 | Pagination Pagination
10 | TimeRange TimeRange
11 | IsTrashed bool
12 | HasCreatedAt bool
13 | }
14 |
15 | type IFetchMultipleRequest interface {
16 | GetEntityName() string
17 | GetFilter() map[string]interface{}
18 | GetPagination() IPagination
19 | GetTimeRange() ITimeRange
20 | GetTrashed() bool
21 | ContainsCreatedAt() bool
22 | }
23 |
24 | // GetEntityName : name of the entity to be fetched.
25 | func (fr FetchMultipleRequest) GetEntityName() string {
26 | return fr.EntityName
27 | }
28 |
29 | // GetFilter : filter conditions for fetching the entities.
30 | func (fr FetchMultipleRequest) GetFilter() map[string]interface{} {
31 | return fr.Filter
32 | }
33 |
34 | // GetPagination : skip and limit values for multiple fetch.
35 | func (fr FetchMultipleRequest) GetPagination() IPagination {
36 | return fr.Pagination
37 | }
38 |
39 | // GetTimeRange : time range for multiple fetch query.
40 | func (fr FetchMultipleRequest) GetTimeRange() ITimeRange {
41 | return fr.TimeRange
42 | }
43 |
44 | // GetTrashed : fetch soft deleted entities too if true.
45 | func (fr FetchMultipleRequest) GetTrashed() bool {
46 | return fr.IsTrashed
47 | }
48 |
49 | // ContainsCreatedAt: account_balance table doesn't have create_at field in capital_loc..
50 | func (fr FetchMultipleRequest) ContainsCreatedAt() bool {
51 | return fr.HasCreatedAt
52 | }
53 |
54 | // skip and limit values for multiple fetch.
55 | type Pagination struct {
56 | Limit int
57 | Skip int
58 | }
59 |
60 | type IPagination interface {
61 | GetLimit() int
62 | GetOffset() int
63 | }
64 |
65 | // GetLimit : get limit value for pagination.
66 | func (p Pagination) GetLimit() int {
67 | if p.Limit > MaxLimit {
68 | return MaxLimit
69 | }
70 | return p.Limit
71 | }
72 |
73 | // GetOffset : get skip value for pagination.
74 | func (p Pagination) GetOffset() int {
75 | return p.Skip
76 | }
77 |
// TimeRange is the [From, To] created_at window of a fetch. FetchMultiple
// applies it only when both bounds are non-zero.
type TimeRange struct {
	From int64
	To   int64
}

// ITimeRange exposes the bounds of a TimeRange.
type ITimeRange interface {
	GetTo() int64
	GetFrom() int64
}

// GetFrom returns the lower bound timestamp.
func (r TimeRange) GetFrom() int64 { return r.From }

// GetTo returns the upper bound timestamp.
func (r TimeRange) GetTo() int64 { return r.To }
98 |
// FetchMultipleResponse carries the records returned by FetchMultiple, keyed
// by entity name.
type FetchMultipleResponse struct {
	entities map[string]interface{}
}

// IFetchMultipleResponse exposes the fetched records.
type IFetchMultipleResponse interface {
	GetEntities() interface{}
}

// GetEntities returns the entity-name -> records map as an interface{}.
func (r FetchMultipleResponse) GetEntities() interface{} { return r.entities }
112 |
// FetchRequest describes a fetch of a single record, by id, of one registered
// entity.
type FetchRequest struct {
	EntityName   string
	ID           string
	IsTrashed    bool
	HasCreatedAt bool
}

// IFetchRequest is the read-only view Fetch consumes.
type IFetchRequest interface {
	GetID() string
	GetTrashed() bool
	GetEntityName() string
}

// GetID returns the id of the record to fetch.
func (r FetchRequest) GetID() string { return r.ID }

// GetEntityName returns the registered name of the entity to fetch.
func (r FetchRequest) GetEntityName() string { return r.EntityName }

// GetTrashed reports whether a soft-deleted record may be returned.
func (r FetchRequest) GetTrashed() bool { return r.IsTrashed }
141 |
// FetchResponse carries the single record returned by Fetch.
type FetchResponse struct {
	entity interface{}
}

// IFetchResponse is the response for a single entity fetch.
type IFetchResponse interface {
	GetEntity() interface{}
}

// GetEntity returns the fetched record.
func (r FetchResponse) GetEntity() interface{} { return r.entity }
155 |
--------------------------------------------------------------------------------
/pkg/fetcher/manager.go:
--------------------------------------------------------------------------------
1 | package fetcher
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "reflect"
7 | "sync"
8 |
9 | "gorm.io/gorm"
10 | )
11 |
12 | // Client for fetcher package.
13 | type client struct {
14 | db *gorm.DB
15 | entities entityList
16 | }
17 |
18 | // IClient exposed methods of fetcher package.
19 | type IClient interface {
20 | Register(name string, entity IModel, entities interface{}) error
21 | IsEntityRegistered(name string) bool
22 | GetEntityList() []string
23 | Fetch(ctx context.Context, req IFetchRequest) (IFetchResponse, error)
24 | FetchMultiple(ctx context.Context, req IFetchMultipleRequest) (IFetchMultipleResponse, error)
25 | }
26 |
27 | // The entities to be fetched should implement IModel.
28 | type IModel interface {
29 | TableName() string
30 | }
31 |
32 | type entityType struct {
33 | model IModel
34 | models interface{}
35 | }
36 |
37 | type entityList struct {
38 | sync.Mutex
39 | list map[string]entityType
40 | }
41 |
42 | // New : returns a new fetcher client.
43 | func New(db *gorm.DB) IClient {
44 | return &client{
45 | db: db,
46 | entities: entityList{
47 | list: make(map[string]entityType),
48 | },
49 | }
50 | }
51 |
52 | // Register : register the entity with the fetcher.
53 | // An entity can be registered only once. Only the
54 | // entities registered can be fetched using fetcher client.
55 | func (c *client) Register(name string, model IModel, models interface{}) error {
56 | c.entities.Lock()
57 | defer c.entities.Unlock()
58 |
59 | if _, ok := c.entities.list[name]; ok {
60 | return fmt.Errorf("entity %s name already registered", name)
61 | }
62 | c.entities.list[name] = entityType{
63 | model: model,
64 | models: models,
65 | }
66 | return nil
67 | }
68 |
69 | func (c *client) IsEntityRegistered(name string) bool {
70 | if _, ok := c.entities.list[name]; ok {
71 | return true
72 | }
73 | return false
74 | }
75 |
76 | // GetEntityList : returns the list of entities registered with fetcher package.
77 | func (c *client) GetEntityList() []string {
78 | list := make([]string, 0, len(c.entities.list))
79 |
80 | for k := range c.entities.list {
81 | list = append(list, k)
82 | }
83 |
84 | return list
85 | }
86 |
87 | // FetchMultiple : returns multiple record of the entities based on the fetch multiple request parameters.
88 | func (c *client) FetchMultiple(_ context.Context, req IFetchMultipleRequest) (IFetchMultipleResponse, error) {
89 | var (
90 | dataTypes entityType
91 | ok bool
92 | )
93 |
94 | if dataTypes, ok = c.entities.list[req.GetEntityName()]; !ok {
95 | return nil, fmt.Errorf("entity %s not registered with fetcher", req.GetEntityName())
96 | }
97 |
98 | models := clone(dataTypes.models)
99 | var query *gorm.DB
100 | if req.ContainsCreatedAt() {
101 | query = c.db.
102 | Order("created_at DESC").
103 | Limit(req.GetPagination().GetLimit()).
104 | Offset(req.GetPagination().GetOffset())
105 |
106 | if req.GetTimeRange().GetFrom() != 0 && req.GetTimeRange().GetTo() != 0 {
107 | query = query.Where("created_at between ? and ?", req.GetTimeRange().GetFrom(), req.GetTimeRange().GetTo())
108 | }
109 | } else {
110 | query = c.db.
111 | Limit(req.GetPagination().GetLimit()).
112 | Offset(req.GetPagination().GetOffset())
113 | }
114 |
115 | if req.GetTrashed() {
116 | query = query.Unscoped()
117 | }
118 |
119 | if len(req.GetFilter()) >= 1 {
120 | query = query.Where(req.GetFilter())
121 | }
122 | query = query.Find(models)
123 | return FetchMultipleResponse{
124 | entities: map[string]interface{}{
125 | req.GetEntityName(): models,
126 | },
127 | }, query.Error
128 | }
129 |
130 | // Fetch : returns single entity from the using id.
131 | func (c *client) Fetch(_ context.Context, req IFetchRequest) (IFetchResponse, error) {
132 | var (
133 | dataTypes entityType
134 | ok bool
135 | )
136 |
137 | if dataTypes, ok = c.entities.list[req.GetEntityName()]; !ok {
138 | return nil, fmt.Errorf("entity %s not registered with fetcher", req.GetEntityName())
139 | }
140 |
141 | model := clone(dataTypes.model)
142 | query := c.db.Where("id = ?", req.GetID())
143 |
144 | if req.GetTrashed() {
145 | query = query.Unscoped()
146 | }
147 |
148 | query = query.First(model)
149 | return FetchResponse{
150 | entity: model,
151 | }, query.Error
152 | }
153 |
// clone returns a pointer to a freshly allocated zero value of the type src
// points to: &Model{} -> new *Model, &[]Model{} -> new *[]Model.
// src is expected to be a non-nil pointer. reflect.Type.Elem panics for types
// without an element type, and a nil src has no type at all.
func clone(src interface{}) interface{} {
	pointee := reflect.TypeOf(src).Elem()
	fresh := reflect.New(pointee)
	return fresh.Interface()
}
159 |
--------------------------------------------------------------------------------
/pkg/logger/entry_test.go:
--------------------------------------------------------------------------------
1 | package logger_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "testing"
7 |
8 | "github.com/razorpay/trino-gateway/pkg/logger"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func TestEntryWithError(t *testing.T) {
13 | assert := assert.New(t)
14 |
15 | defer func() {
16 | logger.ErrorKey = "error"
17 | }()
18 |
19 | err := fmt.Errorf("doomed here %d", 1234)
20 |
21 | config := logger.Config{
22 | LogLevel: logger.Debug,
23 | }
24 |
25 | lgr, lErr := logger.NewLogger(config)
26 |
27 | assert.Nil(lErr)
28 |
29 | entry := logger.NewEntry(lgr)
30 |
31 | assert.Equal(err.Error(), entry.WithError(err).Data["error"])
32 |
33 | logger.ErrorKey = "err"
34 | assert.Equal(err.Error(), entry.WithError(err).Data["err"])
35 | }
36 |
37 | func TestEntryWithContext(t *testing.T) {
38 | assert := assert.New(t)
39 | ctx := context.WithValue(context.Background(), "foo", "bar")
40 |
41 | config := logger.Config{
42 | LogLevel: logger.Debug,
43 | }
44 |
45 | lgr, err := logger.NewLogger(config)
46 |
47 | assert.Nil(err)
48 |
49 | entry := logger.NewEntry(lgr)
50 |
51 | assert.Equal("bar", entry.WithContext(ctx, []string{"foo"}).Data["foo"])
52 | }
53 |
--------------------------------------------------------------------------------
/pkg/logger/logger_test.go:
--------------------------------------------------------------------------------
1 | package logger_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/razorpay/trino-gateway/pkg/logger"
7 | )
8 |
9 | func BenchmarkNewLogger(b *testing.B) {
10 | config := logger.Config{
11 | LogLevel: logger.Debug,
12 | }
13 | lgr, err := logger.NewLogger(config)
14 | if err != nil {
15 | b.Errorf("NewLogger :%v", err)
16 | }
17 | /*zt := reflect.TypeOf(&(logger.ZapLogger{}))
18 | if !reflect.DeepEqual(lg, zt) {
19 | t.Errorf("NewLogger() = %v, want %v", lg, zt)
20 | }*/
21 | defaultLogger := lgr.WithFields(map[string]interface{}{"key1": "value1"})
22 | for n := 0; n < 1000; n++ {
23 | b.Run("defaultlogger", func(b *testing.B) {
24 | defaultLogger.Debug("I have the default values for Key1 and Value1")
25 | })
26 | }
27 | }
28 |
29 | func LogTest(b *testing.B) {
30 | config := logger.Config{
31 | LogLevel: logger.Debug,
32 | }
33 | lgr, err := logger.NewLogger(config)
34 | if err != nil {
35 | b.Errorf("NewLogger :%v", err)
36 | }
37 | defaultLogger := lgr.WithFields(map[string]interface{}{"key1": "value1"})
38 | for n := 0; n < b.N; n++ {
39 | defaultLogger.Debug("I have the default values for Key1 and Value1")
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/pkg/spine/datatype/validation.go:
--------------------------------------------------------------------------------
1 | package datatype
2 |
import (
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"regexp"
	"strconv"

	validation "github.com/go-ozzo/ozzo-validation/v4"
	"github.com/gobuffalo/nulls"
)
13 |
14 | const (
15 | // regular expression to validation rzp standard ID
16 | RegexRZPID = `^[a-zA-Z0-9]{14}$`
17 |
18 | // regular expression to validation unix timestamp
19 | RegexUnixTimestamp = `^([\d]{10}|0)$`
20 |
21 | // regular expression to validation basic string
22 | // consisting alpha, number, _, - and white space
23 | RegexBasicString = `^[a-zA-Z0-9_\-\s]*$`
24 | )
25 |
26 | // ValidateNullableInt64 checks NullableInt64 against the rules provided
27 | func ValidateNullableInt64(value nulls.Int64, rules ...validation.RuleFunc) validation.RuleFunc {
28 | return func(value interface{}) error {
29 | v := value.(nulls.Int64)
30 | for _, f := range rules {
31 | if err := f(v.Int64); err != nil {
32 | return err
33 | }
34 | }
35 |
36 | return nil
37 | }
38 | }
39 |
40 | // ValidateNullableString checks NullableString against the rules provided
41 | func ValidateNullableString(value nulls.String, rules ...validation.RuleFunc) validation.RuleFunc {
42 | return func(value interface{}) error {
43 | v := value.(nulls.String)
44 | for _, f := range rules {
45 | if err := f(v.String); err != nil {
46 | return err
47 | }
48 | }
49 |
50 | return nil
51 | }
52 | }
53 |
54 | // IsRZPID checks the given string is a valid 14-char Razorpay ID
55 | func IsRZPID(value interface{}) error {
56 | return isValidString(value, RegexRZPID)
57 | }
58 |
59 | // IsTimestamp will validate if the value is a valid unix timestamp or not
60 | func IsTimestamp(value interface{}) error {
61 | if value == nil {
62 | return nil
63 | }
64 |
65 | return MatchRegex(fmt.Sprintf("%v", value), RegexUnixTimestamp)
66 | }
67 |
// MatchRegex reports, as an error, whether value contains a match of regex.
// It returns "invalid regex" when the pattern does not compile and
// "not a valid input" when it compiles but does not match; nil otherwise.
func MatchRegex(value string, regex string) error {
	matched, compileErr := regexp.MatchString(regex, value)
	if compileErr != nil {
		return errors.New("invalid regex")
	}
	if !matched {
		return errors.New("not a valid input")
	}

	return nil
}
78 |
79 | // isValidString checks if given input matches a given regex or not
80 | func isValidString(value interface{}, regex string) error {
81 | // let the nil handled by required validation
82 | if value == nil {
83 | return nil
84 | }
85 |
86 | if str, err := isString(value); err != nil {
87 | return err
88 | } else if str == "" {
89 | return nil
90 | } else {
91 | return MatchRegex(str, regex)
92 | }
93 | }
94 |
// isString returns value as a string, or an error ("must be a string") when
// its dynamic type is not exactly string.
func isString(value interface{}) (string, error) {
	str, ok := value.(string)
	if ok {
		return str, nil
	}

	return "", errors.New("must be a string")
}
103 |
104 | // IsBasicString validates if the given value is basic string
105 | // consisting alphabet, number, _, - and white space
106 | func IsBasicString(value interface{}) error {
107 | return isValidString(value, RegexBasicString)
108 | }
109 |
// IsInt64 validates that value's dynamic type is exactly int64. Other integer
// types, and nil, fail with "must be an int64 integer".
func IsInt64(value interface{}) error {
	switch value.(type) {
	case int64:
		return nil
	default:
		return errors.New("must be an int64 integer")
	}
}
117 |
118 | // IsJson validates if the given value is valid json
119 | func IsJson(value interface{}) error {
120 | if value == nil {
121 | return nil
122 | }
123 |
124 | if str, err := isString(value); err != nil {
125 | return err
126 | } else {
127 | input := []byte(str)
128 | var x struct{}
129 | if err := json.Unmarshal(input, &x); err != nil {
130 | return err
131 | }
132 | }
133 |
134 | return nil
135 | }
136 |
137 | // IsNumeric checks if the given string 's' is float, int, signed / unsigned, exponential
138 | // and returns true if it is valid ..
139 | func IsNumeric(value interface{}) error {
140 | valueStr, err := isString(value)
141 | if err != nil {
142 | return err
143 | }
144 | _, err = strconv.ParseFloat(valueStr, 64)
145 |
146 | return err
147 | }
148 |
// IsBool validates that value is a bool; nil also passes. Anything else fails
// with "must be a boolean".
func IsBool(value interface{}) error {
	if value == nil {
		return nil
	}
	if _, ok := value.(bool); ok {
		return nil
	}

	return errors.New("must be a boolean")
}
160 |
--------------------------------------------------------------------------------
/pkg/spine/db/db_test.go:
--------------------------------------------------------------------------------
1 | package db_test
2 |
3 | import (
4 | "context"
5 | "regexp"
6 | "testing"
7 | "time"
8 |
9 | "gorm.io/plugin/dbresolver"
10 |
11 | "github.com/DATA-DOG/go-sqlmock"
12 | "github.com/stretchr/testify/assert"
13 | "gorm.io/driver/mysql"
14 | "gorm.io/gorm"
15 |
16 | "github.com/razorpay/trino-gateway/pkg/spine"
17 | "github.com/razorpay/trino-gateway/pkg/spine/db"
18 | )
19 |
20 | type TestModel struct {
21 | spine.Model
22 | Name string `json:"name"`
23 | }
24 |
25 | func (t *TestModel) EntityName() string {
26 | return "model"
27 | }
28 |
29 | func (t *TestModel) TableName() string {
30 | return "model"
31 | }
32 |
33 | func (t *TestModel) GetID() string {
34 | return t.ID
35 | }
36 |
37 | func (t *TestModel) Validate() error {
38 | return nil
39 | }
40 |
41 | func (t *TestModel) SetDefaults() error {
42 | return nil
43 | }
44 |
45 | func TestGetConnectionPath(t *testing.T) {
46 | c := getDefaultConfig()
47 | // Asserts connection string for mysql dialect.
48 | assert.Equal(t, "user:pass@tcp(localhost:3307)/database?charset=utf8&parseTime=True&loc=Local", c.GetConnectionPath())
49 | // Asserts connection string for postgres dialect.
50 | c.Dialect = "postgres"
51 | assert.Equal(t, "host=localhost port=3307 dbname=database sslmode=require user=user password=pass", c.GetConnectionPath())
52 |
53 | // invalid dialect
54 | c.Dialect = "invalid"
55 | assert.Equal(t, "", c.GetConnectionPath())
56 | }
57 |
58 | func TestNewDb(t *testing.T) {
59 | tests := []struct {
60 | name string
61 | err string
62 | config db.IConfigReader
63 | options func() ([]func(*db.DB) error, func())
64 | }{
65 | {
66 | name: "success",
67 | config: getDefaultConfig(),
68 | options: func() ([]func(*db.DB) error, func()) {
69 | conn, _, err := sqlmock.New()
70 | assert.Nil(t, err)
71 | return []func(*db.DB) error{
72 | db.Dialector(getGormDialectorForMock(conn)),
73 | }, func() { _ = conn.Close() }
74 | },
75 | },
76 | {
77 | name: "invalid dialect",
78 | err: db.ErrorUndefinedDialect.Error(),
79 | config: &db.Config{ConnectionConfig: db.ConnectionConfig{Dialect: "invalid dialect"}},
80 | options: func() ([]func(*db.DB) error, func()) {
81 | return []func(*db.DB) error{}, func() {}
82 | },
83 | },
84 | {
85 | name: "connect error: no mock",
86 | err: "dial tcp .+:3307: connect: connection refused",
87 | config: getDefaultConfig(),
88 | options: func() ([]func(*db.DB) error, func()) {
89 | return []func(*db.DB) error{}, func() {}
90 | },
91 | },
92 | }
93 |
94 | for _, testCase := range tests {
95 | t.Run(testCase.name, func(t *testing.T) {
96 | options, finish := testCase.options()
97 | defer finish()
98 |
99 | gdb, err := db.NewDb(testCase.config, options...)
100 |
101 | if testCase.err == "" {
102 | assert.Nil(t, err)
103 | assert.NotNil(t, gdb)
104 | } else {
105 | expr, e := regexp.Compile(testCase.err)
106 | assert.Nil(t, e)
107 | assert.NotNil(t, err)
108 | assert.Nil(t, gdb)
109 | assert.Regexp(t, expr, err.Error())
110 | }
111 | })
112 | }
113 | }
114 |
115 | func TestDB_Replica(t *testing.T) {
116 | defConn, _, err := sqlmock.New()
117 | assert.Nil(t, err)
118 | defer defConn.Close()
119 |
120 | replicaConn, replicaMock, err := sqlmock.New()
121 | assert.Nil(t, err)
122 | defer replicaConn.Close()
123 |
124 | sdb, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(defConn)), db.GormConfig(&gorm.Config{}))
125 | assert.Nil(t, err)
126 |
127 | err = sdb.Replicas([]gorm.Dialector{getGormDialectorForMock(replicaConn)}, &db.ConnectionPoolConfig{
128 | MaxOpenConnections: 5,
129 | MaxIdleConnections: 5,
130 | ConnectionMaxLifetime: 5 * time.Minute,
131 | })
132 | assert.Nil(t, err)
133 |
134 | model := TestModel{}
135 |
136 | // 1. Test that with replica select query goes to replica
137 | replicaMock.
138 | ExpectQuery(regexp.QuoteMeta("SELECT * FROM `model` WHERE id = ? ORDER BY `model`.`id` LIMIT 1")).
139 | WithArgs("1").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "name1"))
140 |
141 | tx := sdb.Instance(context.TODO()).Where("id = ?", "1").First(&model)
142 | assert.Nil(t, tx.Error)
143 | assert.Equal(t, model.ID, "1")
144 | }
145 |
146 | func TestDB_WarmStorageDB(t *testing.T) {
147 | defConn, _, err := sqlmock.New()
148 | assert.Nil(t, err)
149 | defer defConn.Close()
150 |
151 | warmStorageConn, warmStorageMock, err := sqlmock.New()
152 | assert.Nil(t, err)
153 | defer warmStorageConn.Close()
154 |
155 | newDB, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(defConn)), db.GormConfig(&gorm.Config{}))
156 | assert.Nil(t, err)
157 |
158 | err = newDB.WarmStorageDB([]gorm.Dialector{getGormDialectorForMock(warmStorageConn)}, &db.ConnectionPoolConfig{
159 | MaxOpenConnections: 5,
160 | MaxIdleConnections: 5,
161 | ConnectionMaxLifetime: 5 * time.Minute,
162 | })
163 | assert.Nil(t, err)
164 |
165 | model := TestModel{}
166 |
167 | // 1. Test that with warm storage select query goes to warm storage
168 | warmStorageMock.
169 | ExpectQuery(regexp.QuoteMeta("SELECT * FROM `model` WHERE id = ? ORDER BY `model`.`id` LIMIT 1")).
170 | WithArgs("1").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "name1"))
171 |
172 | tx := newDB.Instance(context.TODO()).Clauses(dbresolver.Use(db.WarmStorageDBResolverName)).Where("id = ?", "1").First(&model)
173 | assert.Nil(t, tx.Error)
174 | assert.Equal(t, model.ID, "1")
175 | }
176 |
177 | func TestDb_Alive(t *testing.T) {
178 | conn, _, err := sqlmock.New()
179 | assert.Nil(t, err)
180 | defer conn.Close()
181 |
182 | gdb, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(conn)))
183 | assert.Nil(t, err)
184 | assert.NotNil(t, gdb)
185 |
186 | err = gdb.Alive()
187 | assert.Nil(t, err)
188 | }
189 |
190 | func TestDB_Instance(t *testing.T) {
191 | conn, _, err := sqlmock.New()
192 | assert.Nil(t, err)
193 | defer conn.Close()
194 |
195 | gdb1, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(conn)))
196 | assert.Nil(t, err)
197 | assert.NotNil(t, gdb1)
198 |
199 | gdb2, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(conn)))
200 | assert.Nil(t, err)
201 | assert.NotNil(t, gdb2)
202 |
203 | instance1 := gdb1.Instance(context.TODO())
204 | assert.NotNil(t, instance1)
205 |
206 | instance2 := gdb2.Instance(context.TODO())
207 | assert.NotNil(t, instance2)
208 |
209 | ctx := context.WithValue(context.TODO(), db.ContextKeyDatabase, instance2)
210 | tgdb := gdb1.Instance(ctx)
211 | assert.Equal(t, instance2, tgdb)
212 | }
213 |
214 | func getDefaultConfig() *db.Config {
215 | return &db.Config{
216 | ConnectionPoolConfig: db.ConnectionPoolConfig{
217 | MaxOpenConnections: 5,
218 | MaxIdleConnections: 5,
219 | ConnectionMaxLifetime: 5 * time.Minute,
220 | },
221 | ConnectionConfig: db.ConnectionConfig{
222 | Dialect: "mysql",
223 | Protocol: "tcp",
224 | URL: "localhost",
225 | Port: 3307,
226 | Username: "user",
227 | Password: "pass",
228 | SslMode: "require",
229 | Name: "database",
230 | },
231 | }
232 | }
233 |
234 | func getGormDialectorForMock(conn gorm.ConnPool) gorm.Dialector {
235 | return mysql.New(mysql.Config{Conn: conn, SkipInitializeWithVersion: true})
236 | }
237 |
--------------------------------------------------------------------------------
/pkg/spine/errors.go:
--------------------------------------------------------------------------------
1 | package spine
2 |
3 | import (
4 | "errors"
5 |
6 | "github.com/go-sql-driver/mysql"
7 |
8 | "github.com/lib/pq"
9 | "gorm.io/gorm"
10 | )
11 |
// Driver-specific identifiers of a unique-key violation: the postgres
// condition name and the MySQL error number.
const (
	PQCodeUniqueViolation    = "unique_violation"
	MySqlCodeUniqueViolation = 1062
)

// Messages carried by the sentinel errors below.
const (
	errDBError                   = "db_error"
	errNoRowAffected             = "no_row_affected"
	errRecordNotFound            = "record_not_found"
	errValidationFailure         = "validation_failure"
	errUniqueConstraintViolation = "unique_constraint_violation"
)

// Sentinel errors exposed by this package; compare them with errors.Is.
var (
	DBError                   = errors.New(errDBError)
	NoRowAffected             = errors.New(errNoRowAffected)
	RecordNotFound            = errors.New(errRecordNotFound)
	ValidationFailure         = errors.New(errValidationFailure)
	UniqueConstraintViolation = errors.New(errUniqueConstraintViolation)
)
30 |
31 | // GetDBError accepts db instance and the details
32 | // creates appropriate error based on the type of query result
33 | // if there is no error then returns nil
34 | func GetDBError(db *gorm.DB) error {
35 | if db.Error == nil {
36 | return nil
37 | }
38 |
39 | // check of error is specific to dialect
40 | if de, ok := DialectError(db); ok {
41 | // is the specific error is captured then return it
42 | // else try construct further errors
43 | if err := de.ConstructError(); err != nil {
44 | return err
45 | }
46 | }
47 |
48 | // Construct error based on type of db operation
49 | err := func() error {
50 | switch true {
51 | case errors.Is(db.Error, gorm.ErrRecordNotFound):
52 | return RecordNotFound
53 |
54 | default:
55 | return db.Error
56 | }
57 | }()
58 |
59 | // add specific details of error
60 | return err
61 | }
62 |
63 | // GetValidationError wraps the error and returns instance of ValidationError
64 | // if the provided error is nil then it just returns nil
65 | func GetValidationError(err error) error {
66 | if err != nil {
67 | return err
68 | }
69 |
70 | return nil
71 | }
72 |
73 | // DialectError returns true if the error is from dialect
74 | func DialectError(d *gorm.DB) (IDialectError, bool) {
75 | switch d.Error.(type) {
76 | case *pq.Error:
77 | return pqError{d.Error.(*pq.Error)}, true
78 | case *mysql.MySQLError:
79 | return mysqlError{d.Error.(*mysql.MySQLError)}, true
80 | default:
81 | return nil, false
82 | }
83 | }
84 |
85 | // IDialectError interface to handler dialect related errors
86 | type IDialectError interface {
87 | ConstructError() error
88 | }
89 |
90 | // pqError holds the error occurred by postgres
91 | type pqError struct {
92 | err *pq.Error
93 | }
94 |
95 | // ConstructError will create appropriate error based on dialect
96 | func (pqe pqError) ConstructError() error {
97 | switch pqe.err.Code.Name() {
98 | case PQCodeUniqueViolation:
99 | return pqe.err
100 | default:
101 | return nil
102 | }
103 | }
104 |
105 | type mysqlError struct {
106 | err *mysql.MySQLError
107 | }
108 |
109 | // ConstructError will create appropriate error based on dialect
110 | func (msqle mysqlError) ConstructError() error {
111 | switch msqle.err.Number {
112 | case MySqlCodeUniqueViolation:
113 | return msqle.err
114 |
115 | default:
116 | return nil
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
/pkg/spine/model.go:
--------------------------------------------------------------------------------
1 | package spine
2 |
3 | import (
4 | validation "github.com/go-ozzo/ozzo-validation/v4"
5 | "github.com/razorpay/trino-gateway/pkg/spine/datatype"
6 | )
7 |
// Column names shared by spine models.
const (
	AttributeID        = "id"
	AttributeCreatedAt = "created_at"
	AttributeUpdatedAt = "updated_at"
	AttributeDeletedAt = "deleted_at"
)

// Model holds the fields common to entities: a string ID plus creation and
// update times stored as int64 (checked as unix timestamps by Validate).
type Model struct {
	ID        string `json:"id"`
	CreatedAt int64  `json:"created_at"`
	UpdatedAt int64  `json:"updated_at"`
}

// IModel is the contract Repo relies on for the entities it persists.
type IModel interface {
	// Identity and storage.
	TableName() string
	EntityName() string
	GetID() string

	// Hooks Repo.Create runs (SetDefaults first, then Validate) before inserting.
	SetDefaults() error
	Validate() error
}
28 |
29 | // Validate validates base Model.
30 | func (m *Model) Validate() error {
31 | return GetValidationError(
32 | validation.ValidateStruct(
33 | m,
34 | validation.Field(&m.ID, validation.By(datatype.IsRZPID)),
35 | validation.Field(&m.CreatedAt, validation.By(datatype.IsTimestamp)),
36 | validation.Field(&m.UpdatedAt, validation.By(datatype.IsTimestamp)),
37 | ),
38 | )
39 | }
40 |
41 | // GetID gets identifier of entity.
42 | func (m *Model) GetID() string {
43 | return m.ID
44 | }
45 |
46 | // GetCreatedAt gets created time of entity.
47 | func (m *Model) GetCreatedAt() int64 {
48 | return m.CreatedAt
49 | }
50 |
51 | // GetUpdatedAt gets last updated time of entity.
52 | func (m *Model) GetUpdatedAt() int64 {
53 | return m.UpdatedAt
54 | }
55 |
--------------------------------------------------------------------------------
/pkg/spine/repository.go:
--------------------------------------------------------------------------------
1 | package spine
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/razorpay/trino-gateway/pkg/spine/db"
7 | "gorm.io/gorm/clause"
8 | "gorm.io/plugin/dbresolver"
9 |
10 | "gorm.io/gorm"
11 | )
12 |
13 | const updatedAtField = "updated_at"
14 |
15 | type Repo struct {
16 | Db *db.DB
17 | }
18 |
19 | // FindByID fetches the record which matches the ID provided from the entity defined by receiver
20 | // and the result will be loaded into receiver
21 | func (repo Repo) FindByID(ctx context.Context, receiver IModel, id string) error {
22 | q := repo.DBInstance(ctx).Where(AttributeID+" = ?", id).First(receiver)
23 |
24 | return GetDBError(q)
25 | }
26 |
27 | func (repo Repo) FindWithConditionByIDs(
28 | ctx context.Context,
29 | receivers interface{},
30 | condition map[string]interface{},
31 | ids []string,
32 | ) error {
33 |
34 | q := repo.DBInstance(ctx).Where(AttributeID+" in (?)", ids).Where(condition).Find(receivers)
35 |
36 | return GetDBError(q)
37 | }
38 |
39 | // FindByIDs fetches the all the records which matches the IDs provided from the entity defined by receivers
40 | // and the result will be loaded into receivers
41 | func (repo Repo) FindByIDs(ctx context.Context, receivers interface{}, ids []string) error {
42 | q := repo.DBInstance(ctx).Where(AttributeID+" in (?)", ids).Find(receivers)
43 |
44 | return GetDBError(q)
45 | }
46 |
47 | // Create inserts a new record in the entity defined by the receiver
48 | // all data filled in the receiver will inserted
49 | func (repo Repo) Create(ctx context.Context, receiver IModel) error {
50 | if err := receiver.SetDefaults(); err != nil {
51 | return err
52 | }
53 |
54 | if err := receiver.Validate(); err != nil {
55 | return err
56 | }
57 |
58 | q := repo.DBInstance(ctx).Create(receiver)
59 |
60 | return GetDBError(q)
61 | }
62 |
63 | // CreateInBatches insert the value in batches into database
64 | func (repo Repo) CreateInBatches(ctx context.Context, receivers interface{}, batchSize int) error {
65 | q := repo.DBInstance(ctx).CreateInBatches(receivers, batchSize)
66 |
67 | return GetDBError(q)
68 | }
69 |
70 | // Update will update the given receiver model with respect to primary key / id available in it.
71 | // If selective list is non empty, only those fields which are present in the list will be updated.
72 | // Note: When using selectiveList `updated_at` field need not be passed in the list.
73 | func (repo Repo) Update(ctx context.Context, receiver IModel, selectiveList ...string) error {
74 | if len(selectiveList) > 0 {
75 | selectiveList = append(selectiveList, updatedAtField)
76 | }
77 | return repo.updateSelective(ctx, receiver, selectiveList...)
78 | }
79 |
80 | // Delete deletes the given model
81 | // Soft or hard delete of model depends on the models implementation
82 | // if the model composites SoftDeletableModel then it'll be soft deleted
83 | func (repo Repo) Delete(ctx context.Context, receiver IModel) error {
84 | q := repo.DBInstance(ctx).Select(clause.Associations).Delete(receiver)
85 |
86 | return GetDBError(q)
87 | }
88 |
89 | func (repo Repo) ClearAssociations(ctx context.Context, receiver IModel, name string) error {
90 | err := repo.DBInstance(ctx).Model(receiver).Association(name).Clear()
91 |
92 | return err
93 | }
94 |
95 | func (repo Repo) ReplaceAssociations(ctx context.Context, receiver IModel, name string, ass interface{}) error {
96 | err := repo.DBInstance(ctx).Model(receiver).Association(name).Replace(ass)
97 |
98 | return err
99 | }
100 |
101 | // FineMany will fetch multiple records form the entity defined by receiver which matched the condition provided
102 | // note: this wont work for in clause. can be used only for `=` conditions
103 | func (repo Repo) FindMany(
104 | ctx context.Context,
105 | receivers interface{},
106 | condition map[string]interface{}) error {
107 |
108 | q := repo.DBInstance(ctx).Where(condition).Find(receivers)
109 |
110 | return GetDBError(q)
111 | }
112 |
113 | // Preload preload associations with given conditions
114 | // repo.Preload(ctx, "Orders", "state NOT IN (?)", "cancelled").FindMany(ctx, &users)
115 | func (repo Repo) Preload(ctx context.Context, query string, args ...interface{}) *Repo {
116 | return &Repo{
117 | Db: repo.Db.Preload(ctx, query, args),
118 | }
119 | }
120 |
121 | // Transaction will manage the execution inside a transactions
122 | // adds the txn db in the context for downstream use case
123 | func (repo Repo) Transaction(ctx context.Context, fc func(ctx context.Context) error) error {
124 | err := repo.DBInstance(ctx).Transaction(func(tx *gorm.DB) error {
125 | // This will ensure that when db.Instance(context) we return the txn on the context
126 | // & all repo queries are done on this txn. Refer usage in test.
127 | if err := fc(context.WithValue(ctx, db.ContextKeyDatabase, tx)); err != nil {
128 | return err
129 | }
130 |
131 | return GetDBError(tx)
132 | })
133 |
134 | if err == nil {
135 | return nil
136 | }
137 |
138 | // tx.Commit can throw an error which will not be an IError
139 | if iErr, ok := err.(error); ok {
140 | return iErr
141 | }
142 |
143 | // use the default code and wrap err in internal
144 | return err
145 | }
146 |
147 | // IsTransactionActive returns true if a transaction is active
148 | func (repo Repo) IsTransactionActive(ctx context.Context) bool {
149 | _, ok := ctx.Value(db.ContextKeyDatabase).(*gorm.DB)
150 | return ok
151 | }
152 |
153 | // DBInstance returns gorm instance.
154 | // If replicas are specified, for Query, Row callback, will use replicas, unless Write mode specified.
155 | // For Raw callback, statements are considered read-only and will use replicas if the SQL starts with SELECT.
156 | //
157 | func (repo Repo) DBInstance(ctx context.Context) *gorm.DB {
158 | return repo.Db.Instance(ctx)
159 | }
160 |
161 | // WriteDBInstance returns a gorm instance of source/primary db connection.
162 | func (repo Repo) WriteDBInstance(ctx context.Context) *gorm.DB {
163 | return repo.DBInstance(ctx).Clauses(dbresolver.Write)
164 | }
165 |
166 | // WarmStorageDBInstance returns gorm instance of source/primary db connection.
167 | func (repo Repo) WarmStorageDBInstance(ctx context.Context) *gorm.DB {
168 | return repo.DBInstance(ctx).Clauses(dbresolver.Use(db.WarmStorageDBResolverName))
169 | }
170 |
171 | // updateSelective will update the given receiver model with respect to primary key / id available in it.
172 | // If selective list is non empty, only those fields which are present in the list will be updated.
173 | // Note: When using selectiveList `updated_at` field also needs to be explicitly passed in the selectiveList.
174 | func (repo Repo) updateSelective(ctx context.Context, receiver IModel, selectiveList ...string) error {
175 | q := repo.DBInstance(ctx).Model(receiver)
176 |
177 | if len(selectiveList) > 0 {
178 | q = q.Select(selectiveList)
179 | }
180 |
181 | q = q.Updates(receiver)
182 |
183 | if q.RowsAffected == 0 {
184 | return NoRowAffected
185 | }
186 |
187 | return GetDBError(q)
188 | }
189 |
--------------------------------------------------------------------------------
/scripts/compile.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Regenerates the Go/Twirp stubs and the OpenAPI (swagger) spec from
# rpc/gateway/service.proto, then refreshes the vendor directory.
set -euo pipefail

go mod download

# Directory of the grpc-gateway module version required by go.mod; its protos
# are needed on the protoc include path. Asking `go list` for .Dir avoids
# depending on an exported $GOPATH, and the old `-u` flag only added a network
# lookup for newer versions.
grpc_gateway_dir="$(go list -m -mod=mod -f '{{.Dir}}' github.com/grpc-ecosystem/grpc-gateway/v2)"

protoc -I . \
  -I "$grpc_gateway_dir" \
  --openapiv2_opt logtostderr=true \
  --openapiv2_opt generate_unbound_methods=true \
  --openapiv2_out ./third_party/swaggerui \
  --twirp_out=. \
  --go_out=. \
  rpc/gateway/service.proto

go mod vendor
15 |
--------------------------------------------------------------------------------
/scripts/coverage.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/scripts/coverage.sh
--------------------------------------------------------------------------------
/scripts/docker.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Brings the local dev stack (build/docker/dev/docker-compose.yml) up when the
# first argument is "up"; any other argument, or none, tears it down.
set -euo pipefail

compose_file="build/docker/dev/docker-compose.yml"

# Compare only the first argument: "$@" inside [[ ]] is joined into a single
# word, so `docker.sh up --foo` would not match and would run `down` instead.
if [[ "${1:-}" == "up" ]]; then
  docker-compose -f "$compose_file" up -d --build
else
  docker-compose -f "$compose_file" down
fi
--------------------------------------------------------------------------------
/scripts/run-example.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Seeds a gateway running on localhost:8000 with a minimal backend -> group ->
# policy chain (all with id "dev") through the Twirp JSON API.

# post_json SERVICE/METHOD BODY
# POSTs BODY as JSON to the given razorpay.gateway Twirp method with the dev auth key.
post_json() {
  curl --request POST \
    --url "http://localhost:8000/twirp/razorpay.gateway.$1" \
    --header 'Content-Type: application/json' \
    --header 'X-Auth-Key: test123' \
    --data "$2"
}

# Create dummy setup
post_json BackendApi/CreateOrUpdateBackend '{
"hostname": "docker.for.mac.localhost:37000",
"scheme": "https",
"id": "dev",
"external_url": "docker.for.mac.localhost:37000",
"is_enabled": true,
"uptime_schedule": "* * * * *",
"cluster_load": 0,
"threshold_cluster_load": 0,
"stats_updated_at": "0"
}'

post_json GroupApi/CreateOrUpdateGroup '{
"id": "dev",
"backends": ["dev"],
"strategy": "RANDOM",
"last_routed_backend": "dev",
"is_enabled": true
}'

post_json PolicyApi/CreateOrUpdatePolicy '{
"id": "dev",
"rule": {
"type": "listening_port",
"value": "8080"
},
"group": "dev",
"fallback_group": "dev",
"is_enabled": true,
"is_auth_delegated": false,
"set_request_source": "localDev"
}'
48 |
49 |
50 |
--------------------------------------------------------------------------------
/scripts/setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Installs the code-generation tools the project relies on (protoc plugins for
# Go, Twirp and OpenAPI, plus gopherjs). Without an @version suffix, `go install`
# inside the module uses the versions required by go.mod (tracked via tools.go).
# Exits on the first failure instead of silently continuing.
set -euo pipefail

go mod download

tools=(
  github.com/golang/protobuf/protoc-gen-go
  github.com/gopherjs/gopherjs
  github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2
  github.com/twitchtv/twirp/protoc-gen-twirp
)

for tool in "${tools[@]}"; do
  go install "$tool"
done
9 |
--------------------------------------------------------------------------------
/third_party/swaggerui/favicon-16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/third_party/swaggerui/favicon-16x16.png
--------------------------------------------------------------------------------
/third_party/swaggerui/favicon-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/third_party/swaggerui/favicon-32x32.png
--------------------------------------------------------------------------------
/third_party/swaggerui/index.css:
--------------------------------------------------------------------------------
html {
  box-sizing: border-box;
  /* Always show the vertical scrollbar. (An obsolete Firefox-only
     `overflow: -moz-scrollbars-vertical` declaration was removed; current
     browsers reject that value and ignore the declaration.) */
  overflow-y: scroll;
}

/* Every element and pseudo-element inherits box-sizing from its parent, so the
   border-box set on html cascades down yet can be overridden per subtree. */
*,
*:before,
*:after {
  box-sizing: inherit;
}

body {
  margin: 0;
  background: #fafafa;
}
17 |
--------------------------------------------------------------------------------
/third_party/swaggerui/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Swagger UI
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/third_party/swaggerui/oauth2-redirect.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Swagger UI: OAuth2 Redirect
5 |
6 |
7 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/third_party/swaggerui/swagger-initializer.js:
--------------------------------------------------------------------------------
/**
 * Boots Swagger UI once the page has loaded, rendering the gateway's OpenAPI
 * spec into #swagger-ui with the standalone layout and the custom plugins
 * defined in this file.
 */
window.onload = () => {
  //<editor-fold desc="Changeable Configuration Block">

  // the following lines will be replaced by docker/configurator, when it runs in a docker-container
  const swaggerConfig = {
    url: "./rpc/gateway/service.swagger.json",
    dom_id: '#swagger-ui',
    deepLinking: true,
    presets: [SwaggerUIBundle.presets.apis, SwaggerUIStandalonePreset],
    plugins: [
      SwaggerUIBundle.plugins.DownloadUrl,
      HideInfoUrlPartsPlugin,
      HideOperationsUntilAuthorizedPlugin,
    ],
    layout: "StandaloneLayout",
  };

  window.ui = SwaggerUIBundle(swaggerConfig);

  //</editor-fold>
};
23 |
/**
 * Swagger UI plugin that hides the spec URL shown in the info section by
 * replacing the `InfoUrl` component with one that renders nothing.
 * @returns {{wrapComponents: object}} a fresh plugin definition
 */
const HideInfoUrlPartsPlugin = () => ({
  wrapComponents: {
    InfoUrl: () => () => null,
    // To hide the `Base Url` part as well, also add:
    // InfoBasePath: () => () => null,
  },
});
32 |
/**
 * Swagger UI plugin that hides secured operations until the user has
 * authorized. Operations without security requirements are always shown.
 * @returns {{wrapComponents: object}} a fresh plugin definition
 */
const HideOperationsUntilAuthorizedPlugin = function() {
  return {
    wrapComponents: {
      operation: (Ori, system) => (props) => {
        // `security` can be missing from an operation (e.g. a spec without
        // security requirements). Treat that as "not secured" instead of
        // throwing a TypeError on `undefined.size`.
        const security = props.operation.get("security");
        const isOperationSecured = Boolean(security?.size);
        const isOperationAuthorized = props.operation.get("isAuthorized");

        if (!isOperationSecured || isOperationAuthorized) {
          return system.React.createElement(Ori, props);
        }
        return null;
      },
    },
  };
};
48 |
--------------------------------------------------------------------------------
/tools.go:
--------------------------------------------------------------------------------
1 | //go:build tools
2 | // +build tools
3 |
4 | package tools
5 |
6 | import (
7 | // _ "github.com/elliots/protoc-gen-twirp_swagger"
8 | _ "github.com/golang/protobuf/protoc-gen-go"
9 | _ "github.com/gopherjs/gopherjs"
10 | _ "github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2"
11 | _ "github.com/twitchtv/twirp/protoc-gen-twirp"
12 | )
13 |
--------------------------------------------------------------------------------
/web/frontend/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/web/frontend/favicon.ico
--------------------------------------------------------------------------------
/web/frontend/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------