├── .cursorignore ├── .devcontainer ├── devcontainer.json └── docker-compose.extend.yml ├── .dockerignore ├── .editorconfig ├── .github ├── dependabot.yml └── workflows │ ├── ci.yml │ └── genesis.yml ├── .gitignore ├── .vscode ├── extensions.json ├── launch.json └── tasks.json ├── Makefile ├── README.md ├── assets └── design_doc.md ├── build └── docker │ ├── dev │ ├── Dockerfile │ ├── docker-compose.yml │ ├── entrypoint.sh │ └── mysqlconf │ │ └── mysql.cnf │ └── prod │ ├── Dockerfile │ ├── Dockerfile.razorpay │ ├── entrypoint.sh │ └── probe.sh ├── cmd ├── gateway │ └── main.go └── migration │ └── main.go ├── config ├── default.toml ├── dev-docker.toml ├── prod-dev.toml └── prod.toml ├── go.mod ├── go.sum ├── internal ├── boot │ └── boot.go ├── config │ └── config.go ├── constants │ └── contextkeys │ │ └── contextkeys.go ├── frontend │ ├── README.md │ ├── components │ │ ├── adminview.go │ │ ├── dashboardview.go │ │ ├── pageview.go │ │ ├── querylistview.go │ │ ├── queryview.go │ │ └── tabview.go │ ├── core │ │ └── core.go │ ├── dispatcher │ │ └── dispatcher.go │ ├── main.go │ └── server │ │ └── main.go ├── gatewayserver │ ├── backendApi │ │ ├── core.go │ │ ├── server.go │ │ └── validation.go │ ├── database │ │ ├── dbRepo │ │ │ └── dbRepo.go │ │ └── migrations │ │ │ ├── 20210805195304_bootstrap.go │ │ │ ├── 20211203205304_alter_backends.go │ │ │ ├── 20220107205304_increase_query_text_size.go │ │ │ ├── 20240524205304_add_auth_delegation.go │ │ │ └── 20240525205304_add_set_source.go │ ├── groupApi │ │ ├── core.go │ │ ├── server.go │ │ └── validation.go │ ├── healthApi │ │ ├── core.go │ │ └── server.go │ ├── hooks │ │ ├── auth.go │ │ ├── ctx.go │ │ ├── metric.go │ │ └── requestid.go │ ├── metrics │ │ └── metrics.go │ ├── models │ │ ├── backend.go │ │ ├── group.go │ │ ├── policy.go │ │ └── query.go │ ├── policyApi │ │ ├── core.go │ │ ├── core_test.go │ │ ├── server.go │ │ └── validation.go │ ├── queryApi │ │ ├── core.go │ │ ├── server.go │ │ └── validation.go │ └── repo │ 
│ ├── backend.go │ │ ├── group.go │ │ ├── policy.go │ │ └── query.go ├── monitor │ ├── core.go │ ├── metric.go │ ├── monitor.go │ └── trino.go ├── provider │ └── logger.go ├── router │ ├── auth.go │ ├── auth_test.go │ ├── metric.go │ ├── request.go │ ├── request_test.go │ ├── request_type.go │ ├── response.go │ ├── router.go │ └── trinoheaders │ │ ├── trino.go │ │ └── trino_test.go └── utils │ ├── utils.go │ └── utils_test.go ├── pkg ├── config │ ├── config.go │ ├── config_test.go │ └── testdata │ │ └── default.toml ├── fetcher │ ├── fetcher.go │ └── manager.go ├── logger │ ├── entry.go │ ├── entry_test.go │ ├── logger.go │ └── logger_test.go └── spine │ ├── datatype │ └── validation.go │ ├── db │ ├── db.go │ └── db_test.go │ ├── errors.go │ ├── model.go │ └── repository.go ├── rpc └── gateway │ └── service.proto ├── scripts ├── compile.sh ├── coverage.sh ├── docker.sh ├── run-example.sh └── setup.sh ├── third_party └── swaggerui │ ├── favicon-16x16.png │ ├── favicon-32x32.png │ ├── index.css │ ├── index.html │ ├── oauth2-redirect.html │ ├── swagger-initializer.js │ ├── swagger-ui-bundle.js │ ├── swagger-ui-bundle.js.map │ ├── swagger-ui-es-bundle-core.js │ ├── swagger-ui-es-bundle-core.js.map │ ├── swagger-ui-es-bundle.js │ ├── swagger-ui-es-bundle.js.map │ ├── swagger-ui-standalone-preset.js │ ├── swagger-ui-standalone-preset.js.map │ ├── swagger-ui.css │ ├── swagger-ui.css.map │ ├── swagger-ui.js │ └── swagger-ui.js.map ├── tools.go └── web └── frontend ├── favicon.ico └── index.html /.cursorignore: -------------------------------------------------------------------------------- 1 | # Distribution and Environment 2 | dist/* 3 | build/* 4 | venv/* 5 | env/* 6 | *.env 7 | .env.* 8 | virtualenv/* 9 | .python-version 10 | .ruby-version 11 | .node-version 12 | 13 | # Logs and Temporary Files 14 | *.log 15 | *.tsv 16 | *.csv 17 | *.txt 18 | tmp/* 19 | temp/* 20 | .tmp/* 21 | *.temp 22 | *.cache 23 | .cache/* 24 | logs/* 25 | 26 | # Sensitive Data 27 | *.json 28 | 
*.xml 29 | *.yml 30 | *.yaml 31 | *.properties 32 | properties.json 33 | *.sqlite 34 | *.sqlite3 35 | *.dbsql 36 | secrets.* 37 | *secret* 38 | *password* 39 | *credential* 40 | .npmrc 41 | .yarnrc 42 | .aws/* 43 | .config/* 44 | 45 | # Credentials and Keys 46 | *.pem 47 | *.ppk 48 | *.key 49 | *.pub 50 | *.p12 51 | *.pfx 52 | *.htpasswd 53 | *.keystore 54 | *.jks 55 | *.truststore 56 | *.cer 57 | id_rsa* 58 | known_hosts 59 | authorized_keys 60 | .ssh/* 61 | .gnupg/* 62 | .pgpass 63 | 64 | # Config Files 65 | *.conf 66 | *.toml 67 | *.ini 68 | .env.local 69 | .env.development 70 | .env.test 71 | .env.production 72 | config/* 73 | 74 | # Documentation and Notes 75 | *.md 76 | *.mdx 77 | *.rst 78 | *.txt 79 | docs/* 80 | README* 81 | CHANGELOG* 82 | LICENSE* 83 | CONTRIBUTING* 84 | 85 | # Database Files 86 | *.sql 87 | *.db 88 | *.dmp 89 | *.dump 90 | *.backup 91 | *.restore 92 | *.mdb 93 | *.accdb 94 | *.realm* 95 | 96 | # Backup and Archive Files 97 | *.bak 98 | *.backup 99 | *.swp 100 | *.swo 101 | *.swn 102 | *~ 103 | *.old 104 | *.orig 105 | *.archive 106 | *.gz 107 | *.zip 108 | *.tar 109 | *.rar 110 | *.7z 111 | 112 | # Compiled and Binary Files 113 | *.pyc 114 | *.pyo 115 | **/__pycache__/** 116 | *.class 117 | *.jar 118 | *.war 119 | *.ear 120 | *.dll 121 | *.exe 122 | *.so 123 | *.dylib 124 | *.bin 125 | *.obj 126 | 127 | # IDE and Editor Files 128 | .idea/* 129 | *.iml 130 | .vscode/* 131 | .project 132 | .classpath 133 | .settings/* 134 | *.sublime-* 135 | .atom/* 136 | .eclipse/* 137 | *.code-workspace 138 | .history/* 139 | 140 | # Build and Dependency Directories 141 | node_modules/* 142 | bower_components/* 143 | vendor/* 144 | packages/* 145 | jspm_packages/* 146 | .gradle/* 147 | target/* 148 | out/* 149 | 150 | # Testing and Coverage Files 151 | coverage/* 152 | .coverage 153 | htmlcov/* 154 | .pytest_cache/* 155 | .tox/* 156 | junit.xml 157 | test-results/* 158 | 159 | # Mobile Development 160 | *.apk 161 | *.aab 162 | *.ipa 163 | *.xcarchive 164 
| *.provisionprofile 165 | google-services.json 166 | GoogleService-Info.plist 167 | 168 | # Certificate and Security Files 169 | *.crt 170 | *.csr 171 | *.ovpn 172 | *.p7b 173 | *.p7s 174 | *.pfx 175 | *.spc 176 | *.stl 177 | *.pem.crt 178 | ssl/* 179 | 180 | # Container and Infrastructure 181 | *.tfstate 182 | *.tfstate.backup 183 | .terraform/* 184 | .vagrant/* 185 | docker-compose.override.yml 186 | kubernetes/* 187 | 188 | # Design and Media Files (often large and binary) 189 | *.psd 190 | *.ai 191 | *.sketch 192 | *.fig 193 | *.xd 194 | assets/raw/* 195 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/docker-outside-of-docker-compose 3 | { 4 | "name": "Existing Docker Compose (Extend)", 5 | 6 | // Update the 'dockerComposeFile' list if you have more compose files or use different names. 7 | // The .devcontainer/docker-compose.yml file contains any overrides you need/want to make. 8 | "dockerComposeFile": [ 9 | "../build/docker/dev/docker-compose.yml", 10 | "docker-compose.extend.yml" 11 | ], 12 | 13 | // The 'service' property is the name of the service for the container that VS Code should 14 | // use. Update this value and .devcontainer/docker-compose.yml to the real service name. 15 | "service": "trino_gateway", 16 | 17 | // The optional 'workspaceFolder' property is the path VS Code should open by default when 18 | // connected. This is typically a file mount in .devcontainer/docker-compose.yml 19 | "workspaceFolder": "/app", 20 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 21 | // "forwardPorts": [], 22 | 23 | // Use 'postCreateCommand' to run commands after the container is created. 
24 | // "postCreateCommand": "docker --version", 25 | 26 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 27 | "remoteUser": "root", 28 | "customizations": { 29 | "vscode": { 30 | "extensions": [ 31 | "golang.go", 32 | "zxh404.vscode-proto3", 33 | "be5invis.toml" 34 | ] 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /.devcontainer/docker-compose.extend.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | trino_gateway: 4 | # [Optional] Required for ptrace-based debuggers like C++, Go, and Rust 5 | cap_add: 6 | - SYS_PTRACE 7 | security_opt: 8 | - seccomp:unconfined 9 | # Overrides default command so things don't shut down after the process ends. 10 | entrypoint: /bin/sh -c "while sleep 1000; do :; done" 11 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git/ 2 | 3 | Dockerfile* 4 | 5 | .dockerignore 6 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | ; http://editorconfig.org/ 2 | 3 | root = true 4 | 5 | [*] 6 | end_of_line = lf 7 | insert_final_newline = true 8 | charset = utf-8 9 | trim_trailing_whitespace = true 10 | 11 | [*.go] 12 | indent_style = tab 13 | indent_size = 4 14 | 15 | [*.md] 16 | trim_trailing_whitespace = false 17 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: gomod 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "04:00" 8 | timezone: Asia/Calcutta 9 | 10 | 
-------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | build-public-image: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - 13 | name: Set up QEMU 14 | uses: docker/setup-qemu-action@v1 15 | - 16 | name: Set up Docker Buildx 17 | uses: docker/setup-buildx-action@v2 18 | - 19 | name: Login to DockerHub 20 | uses: docker/login-action@v2 21 | with: 22 | username: ${{ secrets.PUBLIC_DOCKER_USERNAME }} 23 | password: ${{ secrets.PUBLIC_DOCKER_PASSWORD }} 24 | - 25 | name: Build and push 26 | id: docker_build 27 | uses: docker/build-push-action@v2 28 | with: 29 | file: build/docker/prod/Dockerfile 30 | push: true 31 | tags: razorpay/presto_gateway:${{ github.sha }} 32 | build-args: GIT_COMMIT_HASH=${{ github.sha }} 33 | - 34 | name: Image digest 35 | run: echo ${{ steps.docker_build.outputs.digest }} 36 | 37 | # Rzp image is built from datahub 38 | -------------------------------------------------------------------------------- /.github/workflows/genesis.yml: -------------------------------------------------------------------------------- 1 | name: Quality Checks 2 | on: 3 | schedule: 4 | - cron: "0 17 * * *" 5 | jobs: 6 | Analysis: 7 | uses: razorpay/genesis/.github/workflows/quality-checks.yml@master 8 | secrets: inherit 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # JetBrains project folders. 2 | .idea/ 3 | 4 | # VSCode test folder 5 | .vscode-test/* 6 | 7 | # Desktop Services Store - Mac. 8 | .DS_Store 9 | 10 | # Vendor modules. 11 | vendor/* 12 | 13 | # App Binaries. 14 | bin/* 15 | 16 | # Documentation. Use `make docs` to generate 17 | docs/* 18 | 19 | # Ignores all mock/* directories in project. 
20 | # These are created by the mockgen tool and only serve unit tests when they run. 21 | mock/ 22 | 23 | # Don't ignore any .gitkeep files, please. 24 | !*.gitkeep 25 | 26 | .tmp 27 | 28 | # This needs to be sourced from the proto3 repo. 29 | proto 30 | 31 | # Generated protobuf files. 32 | rpc/*/*.pb.go 33 | rpc/*/*.twirp.go 34 | 35 | # Generated or local configuration files. 36 | config/dev.* 37 | 38 | # docker-compose .dev file in deployments. 39 | # deployments/docker-compose.dev.yml 40 | 41 | 42 | # Gopherjs compiled files 43 | web/frontend/js/* 44 | 45 | # Generated swagger defs 46 | third_party/swaggerui/rpc/gateway/service.swagger.json 47 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": ["golang.go", "ms-vscode-remote.remote-containers", "be5invis.toml", "zxh404.vscode-proto3", "ms-azuretools.vscode-docker"] 3 | } 4 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "type": "chrome", 9 | "request": "launch", 10 | "name": "Launch gateway GUI in Chrome", 11 | "url": "http://localhost:8000", 12 | "webRoot": "${workspaceFolder}" 13 | } 14 | ] 15 | } 16 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=733558 3 | // for the documentation about the tasks.json format 4 | "version": "2.0.0", 5 | "tasks": [ 6 | { 7 | "label": "serve on local", 8 | "type": "shell", 9 | "command": "go run ./cmd/gateway | jq", 10 | "problemMatcher": [] 11 | }, 12 | { 13 | "label": "serve on local without jq", 14 | "type": "shell", 15 | "command": "go run ./cmd/gateway", 16 | "problemMatcher": [] 17 | }, 18 | { 19 | "label": "send query submission request to localhost:8080", 20 | "type": "shell", 21 | "command": "curl -X POST http://localhost:8080/v1/statement -H 'X-Trino-User: dev' -d 'SELECT 1' ", 22 | "problemMatcher": [] 23 | }, 24 | { 25 | "label": "Run local-dev setup example", 26 | "type": "shell", 27 | "command": "source ./scripts/run-example.sh", 28 | "problemMatcher": [] 29 | } 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SCRIPT_DIR := "./scripts" 2 | 3 | BUILD_OUT_DIR := "bin/" 4 | 5 | PWD = $(shell pwd) 6 | 7 | .PHONY: setup 8 | setup: 9 | $(SCRIPT_DIR)/setup.sh 10 | 11 | .PHONY: dev-setup 12 | dev-setup: setup config/dev.toml 13 | 14 | 15 | config/dev.toml: 16 | touch $(PWD)/config/dev.toml 17 | 18 | .PHONY: build 19 | build: 20 | $(SCRIPT_DIR)/compile.sh 21 | 22 | .PHONY: build-frontend 23 | build-frontend: web/frontend/js/frontend.js 24 | 25 | 
web/frontend/js/frontend.js: 26 | echo "Compiling frontend" 27 | gopherjs build ./internal/frontend --output "./web/frontend/js/frontend.js" --minify --verbose 28 | 29 | # .PHONY: dev-build 30 | # dev-build: 31 | # $(SCRIPT_DIR)/dev.sh 32 | 33 | .PHONY: dev-docker-up 34 | dev-docker-up: 35 | $(SCRIPT_DIR)/docker.sh up 36 | 37 | .PHONY: dev-docker-down 38 | dev-docker-down: 39 | $(SCRIPT_DIR)/docker.sh down 40 | 41 | .PHONY: dev-docker-run-example ## Runs bundled example in dev docker env 42 | dev-docker-run-example: 43 | $(SCRIPT_DIR)/run-example.sh 44 | # Build the migration CLI into the gitignored bin/ dir (flags must precede the package path), then apply pending migrations. 45 | .PHONY: dev-migration 46 | dev-migration: 47 | go build -o bin/migration ./cmd/migration 48 | ./bin/migration up 49 | 50 | .PHONY: test-integration 51 | test-integration: 52 | go test -tags=integration ./test/it -v -count=1 53 | 54 | .PHONY: test-unit 55 | test-unit: 56 | go test 57 | -------------------------------------------------------------------------------- /assets/design_doc.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/assets/design_doc.md -------------------------------------------------------------------------------- /build/docker/dev/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.22.3-alpine3.20 2 | 3 | WORKDIR /app 4 | 5 | 6 | RUN apk update \ 7 | && apk add --no-cache bash make protobuf protobuf-dev git gzip curl build-base 8 | 9 | # COPY ./ /app 10 | # RUN go mod download 11 | 12 | RUN go install github.com/githubnemo/CompileDaemon@v1.4.0 13 | 14 | COPY ./go.mod /app/go.mod 15 | COPY ./go.sum /app/go.sum 16 | 17 | # RUN go mod vendor 18 | RUN go mod download 19 | 20 | ENTRYPOINT /app/build/docker/dev/entrypoint.sh 21 | -------------------------------------------------------------------------------- /build/docker/dev/docker-compose.yml:
-------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | trino_gateway: 4 | build: 5 | context: ./../../.. 6 | dockerfile: build/docker/dev/Dockerfile 7 | image: trino_gateway 8 | # image: utkarshsaxena/utk_trino_gateway:v6 9 | container_name: trino_gateway 10 | volumes: 11 | - ./../../..:/app 12 | environment: 13 | APP_ENV: dev-docker 14 | TRINO-GATEWAY_AUTH_ROUTER_AUTHENTICATE: true 15 | entrypoint: /app/build/docker/dev/entrypoint.sh 16 | # entrypoint: ["tail", "-f", "/dev/null"] 17 | expose: 18 | - "8000" 19 | - "8001" 20 | - "8002" 21 | - "8080" 22 | - "8081" 23 | ports: 24 | - 28000:8000 25 | - 28001:8001 26 | - 28002:8002 27 | - 28080:8080 28 | - 28081:8081 29 | networks: 30 | - default 31 | links: 32 | - trino_gateway_mysql 33 | trino_gateway_mysql: 34 | image: mysql:8.0.28-oracle 35 | container_name: trino_gateway_mysql 36 | volumes: 37 | - ./mysqlconf:/etc/mysql/conf.d 38 | ports: 39 | - 33306:3306 40 | environment: 41 | MYSQL_ROOT_PASSWORD: root123 42 | MYSQL_DATABASE: trino-gateway 43 | networks: 44 | - default 45 | # networks: 46 | # trino: 47 | # external: true 48 | -------------------------------------------------------------------------------- /build/docker/dev/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | initialize() { 4 | # Init app secrets + envvars 5 | echo "Syncing app deps, if this takes time, update deps in the built image" 6 | make setup build 7 | } 8 | 9 | check_db_connection() { 10 | # Wait till db is available 11 | connected=0 12 | counter=0 13 | 14 | echo "Wait 60 seconds for connection to MySQL" 15 | while [[ ${counter} -lt 60 ]]; do 16 | { 17 | echo "Connecting to MySQL" && go run ./cmd/migration/main.go version && 18 | connected=1 19 | 20 | } || { 21 | let counter=$counter+3 22 | sleep 3 23 | } 24 | if [[ ${connected} -eq 1 ]]; then 25 | echo "Connected" 26 | break; 27 | fi 28 | done 29 | 30 | if [[ 
${connected} -eq 0 ]]; then 31 | echo "MySQL connection failed." 32 | exit 1; # non-zero so the container is reported as failed, not as a clean exit 33 | fi 34 | } 35 | 36 | db_migrations() { 37 | go run ./cmd/migration/main.go up 38 | } 39 | 40 | initialize 41 | check_db_connection 42 | # run db migrations 43 | db_migrations 44 | 45 | # run app 46 | 47 | # CompileDaemon -polling-interval=10 -exclude-dir=.git -exclude-dir=vendor --build="gopherjs build ./internal/frontend/main --output "./web/frontend/js/frontend.js" --verbose && go build cmd/gateway/main.go -o gateway" --command=./gateway 48 | go run ./cmd/gateway/main.go 49 | # tail -f /dev/null 50 | -------------------------------------------------------------------------------- /build/docker/dev/mysqlconf/mysql.cnf: -------------------------------------------------------------------------------- 1 | [mysqld] 2 | skip-host-cache 3 | skip-name-resolve 4 | general_log=1 5 | general_log_file=/var/lib/mysql/general.log 6 | -------------------------------------------------------------------------------- /build/docker/prod/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.22.3-alpine3.20 2 | 3 | ARG GIT_COMMIT_HASH 4 | ENV GIT_COMMIT_HASH=${GIT_COMMIT_HASH} 5 | ENV TRINO-GATEWAY_APP_GITCOMMITHASH=${GIT_COMMIT_HASH} 6 | 7 | WORKDIR /app 8 | 9 | 10 | RUN apk update \ 11 | && apk add --no-cache bash make protobuf protobuf-dev git gzip curl build-base 12 | 13 | COPY ./ /app 14 | 15 | RUN go mod download \ 16 | && make setup build 17 | # RUN go mod vendor 18 | 19 | COPY ./build/docker/prod/probe.sh /app/probe.sh 20 | 21 | ENTRYPOINT /app/build/docker/prod/entrypoint.sh 22 | -------------------------------------------------------------------------------- /build/docker/prod/Dockerfile.razorpay: -------------------------------------------------------------------------------- 1 | ## Dockerfile used for Rzp deployment 2 | 3 | FROM c.rzp.io/razorpay/onggi-multi-arch:rzp-golden-image-base-golang-1.22 4 | # TODO:
c.rzp.io/razorpay/rzp-docker-image-inventory-multi-arch:rzp-golden-image-base-golang-1.22 5 | 6 | ARG GIT_COMMIT_HASH 7 | ENV GIT_COMMIT_HASH=${GIT_COMMIT_HASH} 8 | ENV TRINO-GATEWAY_APP_GITCOMMITHASH=${GIT_COMMIT_HASH} 9 | 10 | WORKDIR /app 11 | 12 | 13 | RUN apk update \ 14 | && apk add --no-cache bash make protobuf protobuf-dev git gzip curl build-base 15 | 16 | COPY ./ /app 17 | 18 | RUN go mod download \ 19 | && make setup build 20 | # RUN go mod vendor 21 | 22 | COPY ./build/docker/prod/probe.sh /app/probe.sh 23 | 24 | ENTRYPOINT /app/build/docker/prod/entrypoint.sh 25 | -------------------------------------------------------------------------------- /build/docker/prod/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | initialize() { 4 | # Init app secrets + envvars 5 | echo "Initializing app" 6 | } 7 | 8 | check_db_connection() { 9 | # Wait till db is available 10 | connected=0 11 | counter=0 12 | 13 | echo "Wait 60 seconds for connection to MySQL" 14 | while [[ ${counter} -lt 60 ]]; do 15 | { 16 | echo "Connecting to MySQL" && go run ./cmd/migration/main.go version && 17 | connected=1 18 | 19 | } || { 20 | let counter=$counter+3 21 | sleep 3 22 | } 23 | if [[ ${connected} -eq 1 ]]; then 24 | echo "Connected" 25 | break; 26 | fi 27 | done 28 | 29 | if [[ ${connected} -eq 0 ]]; then 30 | echo "MySQL connection failed." 
31 | exit 1; # non-zero so the container is reported as failed, not as a clean exit 32 | fi 33 | } 34 | 35 | db_migrations() { 36 | go run ./cmd/migration/main.go up 37 | } 38 | 39 | initialize 40 | check_db_connection 41 | # run db migrations 42 | db_migrations 43 | 44 | # run app 45 | 46 | # CompileDaemon -polling-interval=10 -exclude-dir=.git -exclude-dir=vendor --build="gopherjs build ./internal/frontend/main --output "./web/frontend/js/frontend.js" --verbose && go build cmd/gateway/main.go -o gateway" --command=./gateway 47 | go run ./cmd/gateway/main.go 48 | # tail -f /dev/null 49 | -------------------------------------------------------------------------------- /build/docker/prod/probe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #validating the status code of health check api and returning 1 in case of non 2xx status code 3 | if [ "$(curl -s -o /dev/null -H 'Content-Type: application/json' -d '{"service": ""}' -w '%{http_code}' http://localhost:8000/twirp/razorpay.gateway.HealthCheckAPI/Check)" != 200 ]; 4 | then 5 | exit 1; 6 | fi 7 | -------------------------------------------------------------------------------- /cmd/migration/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "os" 9 | 10 | "github.com/pressly/goose/v3" 11 | "github.com/razorpay/trino-gateway/internal/boot" 12 | _ "github.com/razorpay/trino-gateway/internal/gatewayserver/database/migrations" 13 | ) 14 | 15 | var ( 16 | flags = flag.NewFlagSet("goose", flag.ExitOnError) 17 | dir = flags.String("dir", "internal/gatewayserver/database/migrations", "Directory with migration files") 18 | verbose = flags.Bool("v", false, "Enable verbose mode") 19 | ) 20 | 21 | func main() { 22 | environ := boot.GetEnv() 23 | 24 | flags.Usage = usage 25 | if err := flags.Parse(os.Args[1:]); err != nil { 26 | log.Fatalf("error parsing flags: %v", err) 27 | } 28 | args := flags.Args() 29 | if *verbose { 30
| goose.SetVerbose(true) 31 | } 32 | 33 | // I.e. no command provided, hence print usage and return. 34 | if len(args) < 1 { 35 | flags.Usage() 36 | return 37 | } 38 | 39 | // Prepares command and arguments for goose's run. 40 | command := args[0] 41 | arguments := []string{} 42 | if len(args) > 1 { 43 | arguments = append(arguments, args[1:]...) 44 | } 45 | 46 | // If command is create or fix, no need to connect to db and hence the 47 | // specific case handling. 48 | switch command { 49 | case "create": 50 | if err := goose.Run("create", nil, *dir, arguments...); err != nil { 51 | log.Fatalf("failed to run command: %v", err) 52 | } 53 | return 54 | case "fix": 55 | if err := goose.Run("fix", nil, *dir); err != nil { 56 | log.Fatalf("failed to run command: %v", err) 57 | } 58 | return 59 | } 60 | 61 | // For other commands boot application (hence getting db and config ready). 62 | // Read application's dialect and get sqldb instance. 63 | if err := boot.InitMigration(context.Background(), environ); err != nil { 64 | log.Fatalf("failed to run command: %v", err) 65 | } 66 | 67 | dialect := boot.Config.Db.Dialect 68 | if err := goose.SetDialect(dialect); err != nil { 69 | log.Fatalf("failed to run command: %v", err) 70 | } 71 | sqldb, err := boot.DB.Instance(context.Background()).DB() 72 | if err != nil { 73 | log.Fatalf("failed to run command: %v", err) 74 | } 75 | 76 | // Finally, executes the goose's command. 
77 | if err := goose.Run(command, sqldb, *dir, arguments...); err != nil { 78 | log.Fatalf("failed to run command: %v", err) 79 | } 80 | 81 | } 82 | 83 | func usage() { 84 | flags.PrintDefaults() 85 | fmt.Println(usageCommands) 86 | } 87 | 88 | var ( 89 | usageCommands = ` 90 | Commands: 91 | up Migrate the DB to the most recent version available 92 | up-to VERSION Migrate the DB to a specific VERSION 93 | down Roll back the version by 1 94 | down-to VERSION Roll back to a specific VERSION 95 | redo Re-run the latest migration 96 | reset Roll back all migrations 97 | status Dump the migration status for the current DB 98 | version Print the current version of the database 99 | create NAME Creates new migration file with the current timestamp 100 | fix Apply sequential ordering to migrations 101 | ` 102 | ) 103 | -------------------------------------------------------------------------------- /config/default.toml: -------------------------------------------------------------------------------- 1 | [app] 2 | env = "default" 3 | gitCommitHash = "nil" 4 | logLevel = "info" 5 | metricsPort = 8002 6 | # gui & twirp app need to be on same port for now, check frontend README for more details 7 | # guiPort = 8000 8 | port = 8000 9 | serviceExternalHostname = "localhost:8080" 10 | serviceHostname = "$$internalHost" 11 | serviceName = "trino-gateway" 12 | shutdownDelay = 2 13 | shutdownTimeout = 5 14 | 15 | [db] 16 | [db.ConnectionConfig] 17 | dialect = "mysql" 18 | protocol = "tcp" 19 | url = "localhost" 20 | port = 33306 21 | username = "root" 22 | password = "root123" 23 | sslMode = "require" 24 | name = "trino-gateway" 25 | [db.ConnectionPoolConfig] 26 | maxOpenConnections = 5 27 | maxIdleConnections = 5 28 | connectionMaxLifetime = 0 29 | 30 | [auth] 31 | token = "test123" 32 | tokenHeaderKey = "X-Auth-Key" 33 | [auth.router.delegatedAuth] 34 | validationProviderURL = "localhost:28001" 35 | validationProviderToken = "test123" 36 | cacheTTLMinutes = "10m" 37 | 38 | 39 | 
40 | 41 | [gateway] 42 | ports = [8080, 8081] 43 | defaultRoutingGroup = "adhoc" 44 | # empty will mean 0.0.0.0 which is required only if running inside docker container, set to `localhost` otherwise 45 | network = "" 46 | 47 | [monitor] 48 | interval = "10m" 49 | statsValiditySecs = 0 50 | healthCheckSql = "SELECT 1" 51 | [monitor.trino] 52 | user = "trino-gateway" 53 | password = "" 54 | -------------------------------------------------------------------------------- /config/dev-docker.toml: -------------------------------------------------------------------------------- 1 | [app] 2 | gitCommitHash = "dev-docker" 3 | logLevel = "debug" 4 | 5 | [db] 6 | [db.ConnectionConfig] 7 | url = "trino_gateway_mysql" 8 | port = 3306 9 | 10 | [gateway] 11 | defaultRoutingGroup = "dev" 12 | -------------------------------------------------------------------------------- /config/prod-dev.toml: -------------------------------------------------------------------------------- 1 | [app] 2 | logLevel = "info" 3 | [monitor] 4 | interval = "20s" 5 | healthCheckSql = "SHOW SCHEMAS FROM hive" 6 | [gateway] 7 | ports = [8080, 8081, 8082, 8083] 8 | -------------------------------------------------------------------------------- /config/prod.toml: -------------------------------------------------------------------------------- 1 | [app] 2 | logLevel = "warn" 3 | [monitor] 4 | interval = "20s" 5 | healthCheckSql = "SHOW SCHEMAS FROM hive" 6 | [gateway] 7 | ports = [8080, 8081, 8082, 8083, 8084, 8085, 8086, 8087, 8088] 8 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/razorpay/trino-gateway 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.5.2 7 | github.com/NYTimes/gziphandler v1.1.1 8 | github.com/dlmiddlecote/sqlstats v1.0.2 9 | github.com/fatih/structs v1.1.0 10 | github.com/go-co-op/gocron v1.35.0 // v1.35.0+ are broken 
for v1 11 | github.com/go-ozzo/ozzo-validation v3.6.0+incompatible 12 | github.com/go-ozzo/ozzo-validation/v4 v4.3.0 13 | github.com/go-sql-driver/mysql v1.8.1 14 | github.com/gobuffalo/nulls v0.4.2 15 | github.com/golang/protobuf v1.5.4 16 | github.com/gopherjs/gopherjs v1.19.0-beta1 17 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 18 | github.com/hexops/vecty v0.6.0 19 | github.com/lib/pq v1.10.9 20 | github.com/pressly/goose/v3 v3.20.0 21 | github.com/prometheus/client_golang v1.19.1 22 | github.com/robfig/cron/v3 v3.0.1 23 | github.com/rs/xid v1.5.0 24 | github.com/spf13/viper v1.18.2 25 | github.com/stretchr/testify v1.9.0 26 | github.com/trinodb/trino-go-client v0.315.0 27 | github.com/twitchtv/twirp v8.1.3+incompatible 28 | go.uber.org/zap v1.27.0 29 | google.golang.org/protobuf v1.34.1 30 | gorm.io/driver/mysql v1.1.2 31 | gorm.io/driver/postgres v1.1.2 32 | gorm.io/gorm v1.21.16 33 | gorm.io/plugin/dbresolver v1.1.0 34 | ) 35 | 36 | require ( 37 | filippo.io/edwards25519 v1.1.0 // indirect 38 | github.com/beorn7/perks v1.0.1 // indirect 39 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 40 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 41 | github.com/evanw/esbuild v0.21.4 // indirect 42 | github.com/fsnotify/fsnotify v1.7.0 // indirect 43 | github.com/gofrs/uuid v4.4.0+incompatible // indirect 44 | github.com/google/uuid v1.6.0 // indirect 45 | github.com/hashicorp/go-uuid v1.0.3 // indirect 46 | github.com/hashicorp/hcl v1.0.0 // indirect 47 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 48 | github.com/jackc/chunkreader/v2 v2.0.1 // indirect 49 | github.com/jackc/pgconn v1.14.3 // indirect 50 | github.com/jackc/pgio v1.0.0 // indirect 51 | github.com/jackc/pgpassfile v1.0.0 // indirect 52 | github.com/jackc/pgproto3/v2 v2.3.3 // indirect 53 | github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect 54 | github.com/jackc/pgtype v1.14.3 // indirect 55 | github.com/jackc/pgx/v4 v4.18.3 
// indirect 56 | github.com/jcmturner/gofork v1.7.6 // indirect 57 | github.com/jinzhu/inflection v1.0.0 // indirect 58 | github.com/jinzhu/now v1.1.5 // indirect 59 | github.com/magiconair/properties v1.8.7 // indirect 60 | github.com/mfridman/interpolate v0.0.2 // indirect 61 | github.com/mitchellh/mapstructure v1.5.0 // indirect 62 | github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86 // indirect 63 | github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c // indirect 64 | github.com/pelletier/go-toml/v2 v2.2.2 // indirect 65 | github.com/pkg/errors v0.9.1 // indirect 66 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 67 | github.com/prometheus/client_model v0.6.1 // indirect 68 | github.com/prometheus/common v0.53.0 // indirect 69 | github.com/prometheus/procfs v0.15.0 // indirect 70 | github.com/sagikazarmark/locafero v0.4.0 // indirect 71 | github.com/sagikazarmark/slog-shim v0.1.0 // indirect 72 | github.com/sethvargo/go-retry v0.2.4 // indirect 73 | github.com/sirupsen/logrus v1.9.3 // indirect 74 | github.com/sourcegraph/conc v0.3.0 // indirect 75 | github.com/spf13/afero v1.11.0 // indirect 76 | github.com/spf13/cast v1.6.0 // indirect 77 | github.com/spf13/cobra v1.8.0 // indirect 78 | github.com/spf13/pflag v1.0.5 // indirect 79 | github.com/subosito/gotenv v1.6.0 // indirect 80 | github.com/visualfc/goembed v0.3.3 // indirect 81 | go.uber.org/atomic v1.11.0 // indirect 82 | go.uber.org/multierr v1.11.0 // indirect 83 | golang.org/x/crypto v0.23.0 // indirect 84 | golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d // indirect 85 | golang.org/x/sync v0.7.0 // indirect 86 | golang.org/x/sys v0.20.0 // indirect 87 | golang.org/x/term v0.20.0 // indirect 88 | golang.org/x/text v0.15.0 // indirect 89 | golang.org/x/tools v0.21.0 // indirect 90 | google.golang.org/genproto/googleapis/api v0.0.0-20240521202816-d264139d666e // indirect 91 | google.golang.org/genproto/googleapis/rpc 
v0.0.0-20240521202816-d264139d666e // indirect 92 | google.golang.org/grpc v1.64.0 // indirect 93 | gopkg.in/ini.v1 v1.67.0 // indirect 94 | gopkg.in/jcmturner/aescts.v1 v1.0.1 // indirect 95 | gopkg.in/jcmturner/dnsutils.v1 v1.0.1 // indirect 96 | gopkg.in/jcmturner/gokrb5.v6 v6.1.1 // indirect 97 | gopkg.in/jcmturner/rpc.v1 v1.1.0 // indirect 98 | gopkg.in/yaml.v3 v3.0.1 // indirect 99 | ) 100 | -------------------------------------------------------------------------------- /internal/boot/boot.go: -------------------------------------------------------------------------------- 1 | package boot 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | "strings" 8 | 9 | "github.com/dlmiddlecote/sqlstats" 10 | "github.com/fatih/structs" 11 | "github.com/prometheus/client_golang/prometheus" 12 | "github.com/razorpay/trino-gateway/internal/config" 13 | "github.com/razorpay/trino-gateway/internal/constants/contextkeys" 14 | config_reader "github.com/razorpay/trino-gateway/pkg/config" 15 | "github.com/razorpay/trino-gateway/pkg/logger" 16 | "github.com/razorpay/trino-gateway/pkg/spine/db" 17 | "github.com/rs/xid" 18 | ) 19 | 20 | const ( 21 | requestIDHttpHeaderKey = "X-Request-ID" 22 | requestIDCtxKey = "RequestID" 23 | ) 24 | 25 | var ( 26 | // Config contains application configuration values. 27 | Config config.Config 28 | 29 | // DB holds the application db connection. 
30 | DB *db.DB 31 | ) 32 | 33 | func init() { 34 | // Init config 35 | err := config_reader.NewDefaultConfig().Load(GetEnv(), &Config) 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | 40 | InitLogger(context.Background()) 41 | 42 | // Init Db 43 | DB, err = db.NewDb(&Config.Db) 44 | if err != nil { 45 | log.Fatal(err.Error()) 46 | } 47 | } 48 | 49 | // Fetch env for bootstrapping 50 | func GetEnv() string { 51 | environment := os.Getenv("APP_ENV") 52 | if environment == "" { 53 | log.Print("APP_ENV not set defaulting to dev env.", environment) 54 | environment = "dev" 55 | } 56 | 57 | log.Print("Setting app env to ", environment) 58 | 59 | return environment 60 | } 61 | 62 | // GetRequestID gets the request id 63 | // if its already set in the given context 64 | // if there is no requestID set then it'll create a new 65 | // request id and returns the same 66 | func GetRequestID(ctx context.Context) string { 67 | if val, ok := ctx.Value(contextkeys.RequestID).(string); ok { 68 | return val 69 | } 70 | return xid.New().String() 71 | } 72 | 73 | // WithRequestID adds a request if to the context and gives the updated context back 74 | // if the passed requestID is empty then creates one by itself 75 | func WithRequestID(ctx context.Context, requestID string) context.Context { 76 | if requestID == "" { 77 | requestID = xid.New().String() 78 | } 79 | 80 | return context.WithValue(ctx, contextkeys.RequestID, requestID) 81 | } 82 | 83 | // initialize all core dependencies for the application 84 | func initialize(ctx context.Context, env string) error { 85 | log := InitLogger(ctx) 86 | 87 | ctx = context.WithValue(ctx, logger.LoggerCtxKey, log) 88 | 89 | // Puts git commit hash into config. 90 | // This is not read automatically because env variable is not in expected format. 
91 | if v, found := os.LookupEnv("GIT_COMMIT_HASH"); found { 92 | Config.App.GitCommitHash = v 93 | } 94 | 95 | // Register DB stats prometheus collector 96 | dbInstance, err := DB.Instance(ctx).DB() 97 | if err != nil { 98 | return err 99 | } 100 | collector := sqlstats.NewStatsCollector(Config.Db.URL+"-"+Config.Db.Name, dbInstance) 101 | prometheus.MustRegister(collector) 102 | 103 | return nil 104 | } 105 | 106 | func InitApi(ctx context.Context, env string) error { 107 | err := initialize(ctx, env) 108 | if err != nil { 109 | return err 110 | } 111 | 112 | return nil 113 | } 114 | 115 | func InitMigration(ctx context.Context, env string) error { 116 | err := initialize(ctx, env) 117 | if err != nil { 118 | return err 119 | } 120 | 121 | return nil 122 | } 123 | 124 | // // InitTracing initialises opentracing exporter 125 | // func InitTracing(ctx context.Context) (io.Closer, error) { 126 | // t, closer, err := tracing.Init(Config.Tracing, Logger(ctx)) 127 | 128 | // Tracer = t 129 | 130 | // return closer, err 131 | // } 132 | 133 | // NewContext adds core key-value e.g. service name, git hash etc to 134 | // existing context or to a new background context and returns. 
135 | func NewContext(ctx context.Context) context.Context { 136 | if ctx == nil { 137 | ctx = context.Background() 138 | } 139 | for k, v := range structs.Map(struct { 140 | GitCommitHash string 141 | Env string 142 | ServiceName string 143 | }{ 144 | GitCommitHash: Config.App.GitCommitHash, 145 | Env: Config.App.Env, 146 | ServiceName: Config.App.ServiceName, 147 | }) { 148 | key := strings.ToLower(k) 149 | ctx = context.WithValue(ctx, key, v) 150 | } 151 | return ctx 152 | } 153 | 154 | func InitLogger(ctx context.Context) *logger.ZapLogger { 155 | lgrConfig := logger.Config{ 156 | LogLevel: Config.App.LogLevel, 157 | ContextString: "trino-gateway", 158 | } 159 | 160 | Logger, err := logger.NewLogger(lgrConfig) 161 | if err != nil { 162 | panic("failed to initialize logger") 163 | } 164 | 165 | return Logger 166 | } 167 | -------------------------------------------------------------------------------- /internal/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/razorpay/trino-gateway/pkg/spine/db" 5 | ) 6 | 7 | type Config struct { 8 | App App 9 | Auth Auth 10 | Db db.Config 11 | Gateway Gateway 12 | Monitor Monitor 13 | } 14 | 15 | // App contains application-specific config values 16 | type App struct { 17 | Env string 18 | GitCommitHash string 19 | LogLevel string 20 | MetricsPort int 21 | Port int 22 | ServiceExternalHostname string 23 | ServiceHostname string 24 | ServiceName string 25 | ShutdownDelay int 26 | ShutdownTimeout int 27 | } 28 | 29 | type Auth struct { 30 | Token string 31 | TokenHeaderKey string 32 | Router struct { 33 | DelegatedAuth struct { 34 | ValidationProviderURL string 35 | ValidationProviderToken string 36 | CacheTTLMinutes string 37 | } 38 | } 39 | } 40 | 41 | type Gateway struct { 42 | DefaultRoutingGroup string 43 | Ports []int 44 | Network string 45 | } 46 | 47 | type Monitor struct { 48 | Interval string 49 | StatsValiditySecs int 
50 | Trino struct { 51 | User string 52 | Password string 53 | } 54 | HealthCheckSql string 55 | } 56 | -------------------------------------------------------------------------------- /internal/constants/contextkeys/contextkeys.go: -------------------------------------------------------------------------------- 1 | package contextkeys 2 | 3 | type contextkeys int 4 | 5 | const RequestID contextkeys = iota 6 | -------------------------------------------------------------------------------- /internal/frontend/README.md: -------------------------------------------------------------------------------- 1 | ### Notes 2 | Gopherjs compiles to js to run entirely in client browser 3 | unlike java nashorn or similar engine in jvm which allows running js code 4 | we dont hav that luxury, 5 | as a result the gatewayserver API and UI need to reside under same endpoint or else 6 | due to CORS https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS/Errors/CORSMissingAllowOrigin 7 | browsers won't allow connectivity between both. 8 | This also solves the problem of maintaining state for UI as its completely client side 9 | 10 | ### Running in dev mode for testing 11 | 12 | The easiest way to run the examples as WebAssembly is via [`wasmserve`](https://github.com/hajimehoshi/wasmserve). 
13 | 14 | Install it (**using Go 1.14+**): 15 | 16 | ```bash 17 | go get -u github.com/hajimehoshi/wasmserve 18 | ``` 19 | 20 | Then run an example: 21 | 22 | ```bash 23 | cd trino-gateway/internal/frontend 24 | wasmserve 25 | ``` 26 | 27 | Then navigate to http://localhost:8080/ -------------------------------------------------------------------------------- /internal/frontend/components/adminview.go: -------------------------------------------------------------------------------- 1 | package components 2 | -------------------------------------------------------------------------------- /internal/frontend/components/dashboardview.go: -------------------------------------------------------------------------------- 1 | package components 2 | -------------------------------------------------------------------------------- /internal/frontend/components/pageview.go: -------------------------------------------------------------------------------- 1 | package components 2 | 3 | import ( 4 | "github.com/hexops/vecty" 5 | "github.com/hexops/vecty/elem" 6 | "github.com/razorpay/trino-gateway/internal/frontend/core" 7 | ) 8 | 9 | // PageView is a vecty.Component which represents the entire page. 
10 | type PageView struct { 11 | vecty.Core 12 | core core.ICore 13 | } 14 | 15 | func GetNewPageViewComponent(c core.ICore) *PageView { 16 | return &PageView{core: c} 17 | } 18 | 19 | func (p *PageView) Render() vecty.ComponentOrHTML { 20 | return elem.Body( 21 | elem.Section( 22 | vecty.Markup( 23 | vecty.Class("section"), 24 | ), 25 | elem.Div( 26 | vecty.Markup( 27 | vecty.Class("container"), 28 | ), 29 | p.renderHeader(), 30 | NewQueryListView(p.core), 31 | p.renderFooter(), 32 | ), 33 | ), 34 | ) 35 | } 36 | 37 | func (p *PageView) renderHeader() *vecty.HTML { 38 | return elem.Div( 39 | vecty.Markup( 40 | vecty.Class("tabs", "is-centered", "is-large", "is-fullwidth", "is-toggle", "is-toggle-rounded"), 41 | ), 42 | elem.UnorderedList( 43 | elem.ListItem( 44 | vecty.Markup( 45 | vecty.Class("is-active"), 46 | ), 47 | &TabView{title: "Query History", hrefUrl: "#"}, 48 | ), 49 | elem.ListItem(&TabView{title: "Dashboard", hrefUrl: "#"}), 50 | elem.ListItem(&TabView{title: "Admin", hrefUrl: "/admin/swaggerui"}), 51 | ), 52 | ) 53 | } 54 | 55 | func (p *PageView) renderFooter() *vecty.HTML { 56 | return elem.Footer( 57 | vecty.Markup( 58 | vecty.Class("footer"), 59 | ), 60 | elem.Div( 61 | vecty.Markup( 62 | vecty.Class("content", "has-text-centered"), 63 | ), 64 | vecty.Text("trino gateway footer"), 65 | ), 66 | ) 67 | } 68 | -------------------------------------------------------------------------------- /internal/frontend/components/querylistview.go: -------------------------------------------------------------------------------- 1 | package components 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | 7 | "github.com/hexops/vecty" 8 | "github.com/hexops/vecty/elem" 9 | "github.com/hexops/vecty/event" 10 | "github.com/hexops/vecty/prop" 11 | "github.com/razorpay/trino-gateway/internal/frontend/core" 12 | ) 13 | 14 | // QueryListView is a vecty.Component which represents the query history section 15 | type queryListView struct { 16 | vecty.Core 17 | core core.ICore 18 | 
items vecty.List 19 | params queryListViewParams 20 | } 21 | 22 | type queryListViewParams struct { 23 | Username string 24 | MaxItems int 25 | TotalItems int 26 | PageIndex int 27 | ItemsPerPage int 28 | } 29 | 30 | func NewQueryListView(core core.ICore) *queryListView { 31 | p := &queryListView{ 32 | core: core, 33 | params: queryListViewParams{ 34 | PageIndex: 0, 35 | ItemsPerPage: 100, 36 | MaxItems: 1000, 37 | }, 38 | } 39 | 40 | if err := p.populateItems(); err != nil { 41 | fmt.Printf("%s: %s\n", "Unable to fetch list of queries.", err.Error()) 42 | } 43 | return p 44 | } 45 | 46 | func (p *queryListView) populateItems() error { 47 | // Get total eligible items 48 | queries, err := p.core.GetQueries(p.params.MaxItems, 0, p.params.Username) 49 | if err != nil { 50 | return err 51 | } 52 | p.params.TotalItems = len(queries) 53 | // queries, err = p.core.GetQueries(p.params.ItemsPerPage, p.params.PageIndex*p.params.ItemsPerPage, p.params.Username) 54 | // if err != nil { 55 | // return err 56 | // } 57 | for _, q := range queries { 58 | query := &QueryView{ 59 | Query: q, 60 | } 61 | p.items = append(p.items, query) 62 | } 63 | return nil 64 | } 65 | 66 | func (p *queryListView) Render() vecty.ComponentOrHTML { 67 | return elem.Div( 68 | vecty.Markup( 69 | vecty.Class("container", "tile", "is-vertical", "is-ancestor"), 70 | ), 71 | p.renderHeader(), 72 | p.renderItems(), 73 | p.renderPagination(), 74 | ) 75 | } 76 | 77 | func (p *queryListView) renderHeader() vecty.ComponentOrHTML { 78 | return elem.Div( 79 | vecty.Markup( 80 | vecty.Class("tile", "is-parent"), 81 | ), 82 | elem.Div( 83 | vecty.Markup( 84 | vecty.Class("tile", "is-child"), 85 | ), 86 | vecty.Text(fmt.Sprintf("Total: %d", p.params.TotalItems)), 87 | ), 88 | elem.Div( 89 | vecty.Markup( 90 | vecty.Class("tile", "is-child"), 91 | ), 92 | elem.Input( 93 | vecty.Markup( 94 | vecty.Class("input"), 95 | prop.Type(prop.TypeText), 96 | prop.Placeholder("Search for Username"), // initial textarea text. 
97 | ), 98 | 99 | // When input is typed into the textarea, update the local 100 | // component state and rerender. 101 | // event.Input(func(e *vecty.Event) { 102 | // p.Input = e.Target.Get("value").String() 103 | // vecty.Rerender(p) 104 | // }), 105 | ), 106 | ), 107 | elem.Div( 108 | vecty.Markup( 109 | vecty.Class("tile", "is-child"), 110 | ), 111 | elem.Div( 112 | vecty.Markup( 113 | vecty.Class("select", "is-rounded"), 114 | ), 115 | elem.Select( 116 | elem.Option(vecty.Text("100"+" Entries per page")), 117 | elem.Option(vecty.Text("200"+" Entries per page")), 118 | elem.Option(vecty.Text("500"+" Entries per page")), 119 | ), 120 | ), 121 | 122 | //
123 | // 137 | //
138 | ), 139 | ) 140 | } 141 | 142 | func (p *queryListView) onEditSearchbox(e *vecty.Event) { 143 | } 144 | 145 | func (p *queryListView) onClickPageNavigation(_ *vecty.Event) { 146 | p.params.PageIndex = p.params.PageIndex + 1 147 | vecty.Rerender(p) 148 | } 149 | 150 | func (p *queryListView) renderItems() vecty.ComponentOrHTML { 151 | r := vecty.List{} 152 | for i, v := range p.items { 153 | if i >= p.params.PageIndex*p.params.ItemsPerPage && i < (p.params.PageIndex+1)*p.params.ItemsPerPage { 154 | r = append(r, v) 155 | } 156 | } 157 | return elem.OrderedList(r) 158 | } 159 | 160 | func (p *queryListView) renderPagination() vecty.ComponentOrHTML { 161 | totPag := int(math.Ceil(float64(p.params.TotalItems) / float64(p.params.ItemsPerPage))) 162 | currPag := p.params.PageIndex + 1 163 | 164 | return elem.Div( 165 | vecty.Markup( 166 | vecty.Class("tile", "is-parent"), 167 | ), 168 | elem.Div( 169 | vecty.Markup( 170 | vecty.Class("tile", "is-child"), 171 | ), 172 | elem.Navigation( 173 | vecty.Markup( 174 | vecty.Class("pagination", "is-rounded"), 175 | vecty.Property("role", "navigation"), 176 | vecty.Property("aria-label", "pagination"), 177 | ), 178 | elem.Anchor( 179 | vecty.Markup( 180 | vecty.MarkupIf(currPag == 1, vecty.Style("display", "none")), 181 | vecty.Class("pagination-previous"), 182 | ), 183 | vecty.Text("Previous"), 184 | ), 185 | elem.Anchor( 186 | vecty.Markup( 187 | vecty.MarkupIf(currPag == totPag, vecty.Style("display", "none")), 188 | vecty.Class("pagination-next"), 189 | event.Click(p.onClickPageNavigation).PreventDefault(), 190 | ), 191 | vecty.Text("Next page"), 192 | ), 193 | elem.UnorderedList( 194 | vecty.Markup( 195 | vecty.Class("pagination-list"), 196 | ), 197 | elem.ListItem(elem.Anchor( 198 | vecty.Markup( 199 | vecty.MarkupIf(currPag == 1, vecty.Style("display", "none")), 200 | vecty.Class("pagination-link"), 201 | vecty.Property("aria-label", "Goto page 1"), 202 | ), 203 | vecty.Text("1"), 204 | )), 205 | 
elem.ListItem(elem.Span( 206 | vecty.Markup( 207 | vecty.MarkupIf(currPag == 1, vecty.Style("display", "none")), 208 | vecty.Class("pagination-ellipsis"), 209 | ), 210 | vecty.Text("..."), 211 | )), 212 | elem.ListItem(elem.Anchor( 213 | vecty.Markup( 214 | vecty.Class("pagination-link", "is-current"), 215 | vecty.Property("aria-label", fmt.Sprint("Goto page ", currPag)), 216 | ), 217 | vecty.Text(fmt.Sprint(currPag)), 218 | )), 219 | elem.ListItem(elem.Span( 220 | vecty.Markup( 221 | vecty.MarkupIf(currPag == totPag, vecty.Style("display", "none")), 222 | vecty.Class("pagination-ellipsis"), 223 | ), 224 | vecty.Text("..."), 225 | )), 226 | elem.ListItem(elem.Anchor( 227 | vecty.Markup( 228 | vecty.MarkupIf(currPag == totPag, vecty.Style("display", "none")), 229 | vecty.Class("pagination-link"), 230 | vecty.Property("aria-label", fmt.Sprint("Goto page ", totPag)), 231 | ), 232 | vecty.Text(fmt.Sprint(totPag)), 233 | )), 234 | ), 235 | ), 236 | ), 237 | ) 238 | } 239 | -------------------------------------------------------------------------------- /internal/frontend/components/queryview.go: -------------------------------------------------------------------------------- 1 | package components 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/hexops/vecty" 8 | "github.com/hexops/vecty/elem" 9 | "github.com/hexops/vecty/event" 10 | "github.com/hexops/vecty/prop" 11 | "github.com/hexops/vecty/style" 12 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway" 13 | ) 14 | 15 | // QueryView is a vecty.Component which represents a single item in the queryHistory List. 
16 | type QueryView struct { 17 | vecty.Core 18 | // core core.ICore 19 | 20 | Query *gatewayv1.Query 21 | classes vecty.ClassMap 22 | } 23 | 24 | /* 25 | QueryId - href to trinoUi 26 | User 27 | backendId 28 | GroupId 29 | submittedAt 30 | Text 31 | */ 32 | 33 | func (p *QueryView) Render() vecty.ComponentOrHTML { 34 | return elem.Div( 35 | vecty.Markup( 36 | vecty.Class("box", "tile", "is-parent", "notification", "is-light"), 37 | vecty.Style("display", "flex"), 38 | vecty.Style("flex-direction", "row"), 39 | ), 40 | p.renderMeta(), 41 | p.renderText(), 42 | ) 43 | } 44 | 45 | type QueryMetaItem struct { 46 | vecty.Core 47 | k string 48 | v string 49 | } 50 | 51 | // TODO : FIX it 52 | var classes = vecty.ClassMap{ 53 | "is-info": true, 54 | "is-light": true, 55 | "is-link": false, 56 | } 57 | 58 | func (q *QueryMetaItem) Render() vecty.ComponentOrHTML { 59 | return vecty.Text(q.v) 60 | } 61 | 62 | func (p *QueryView) renderMeta() *vecty.HTML { 63 | // https://trino-gateway.de.razorpay.com/ui/query.html?20211003_083931_06657_n3mb3 64 | url := fmt.Sprintf("%s/ui/query.html?%s", p.Query.ServerHost, p.Query.Id) 65 | 66 | // fails at runtime in js 67 | // loc, _ := time.LoadLocation("Asia/Kolkata") 68 | 69 | _items := [...]*QueryMetaItem{ 70 | { 71 | k: "SubmittedAt", 72 | v: time.Unix(p.Query.GetSubmittedAt(), 0).Local().Format("2006/01/02 15:04:05"), // Golang time format layout is weird stuff. 
73 | }, 74 | {k: "Username", v: p.Query.GetUsername()}, 75 | {k: "BackendId", v: p.Query.GetBackendId()}, 76 | {k: "GroupId", v: p.Query.GetGroupId()}, 77 | } 78 | 79 | var items vecty.List 80 | for _, i := range _items { 81 | item := elem.ListItem(i) 82 | items = append(items, item) 83 | } 84 | 85 | p.classes = classes 86 | 87 | return elem.Div( 88 | vecty.Markup( 89 | p.classes, 90 | vecty.Class("tile", "is-child", "notification"), 91 | style.Width("30%"), 92 | vecty.Style("display", "flex"), 93 | vecty.Style("flex-direction", "column"), 94 | event.PointerEnter(p.onPointerEnter), 95 | event.PointerLeave(p.onPointerLeave), 96 | ), 97 | elem.Bold( 98 | vecty.Markup( 99 | vecty.Class("subtitle"), 100 | ), 101 | elem.Anchor( 102 | vecty.Markup( 103 | prop.Href(url), 104 | ), 105 | vecty.Text(p.Query.GetId()), 106 | ), 107 | ), 108 | elem.UnorderedList(items), 109 | ) 110 | } 111 | 112 | func (p *QueryView) renderText() *vecty.HTML { 113 | return elem.Div( 114 | vecty.Markup( 115 | vecty.Class("tile", "is-child", "is-8"), 116 | style.Width("70%"), 117 | style.Height("6.5em"), 118 | vecty.Style("word-wrap", "break-word"), 119 | style.Overflow(style.OverflowHidden), 120 | vecty.Style("text-overflow", "ellipsis"), 121 | vecty.Style("resize", "vertical"), 122 | vecty.Style("text-align", "center"), 123 | ), 124 | vecty.Text(p.Query.GetText()), 125 | ) 126 | } 127 | 128 | func (p *QueryView) onPointerEnter(e *vecty.Event) { 129 | p.classes["is-info"] = false 130 | p.classes["is-link"] = true 131 | vecty.Rerender(p) 132 | } 133 | 134 | func (p *QueryView) onPointerLeave(e *vecty.Event) { 135 | p.classes["is-info"] = true 136 | p.classes["is-link"] = false 137 | vecty.Rerender(p) 138 | } 139 | -------------------------------------------------------------------------------- /internal/frontend/components/tabview.go: -------------------------------------------------------------------------------- 1 | package components 2 | 3 | import ( 4 | "github.com/hexops/vecty" 5 | 
"github.com/hexops/vecty/elem" 6 | "github.com/hexops/vecty/prop" 7 | ) 8 | 9 | // TabView is a vecty.Component which represents a single elements in the tabBar 10 | type TabView struct { 11 | vecty.Core 12 | title string 13 | isSelected bool 14 | // TODO: remove this 15 | hrefUrl string 16 | component *vecty.ComponentOrHTML 17 | } 18 | 19 | func (p *TabView) Render() vecty.ComponentOrHTML { 20 | return elem.Anchor( 21 | vecty.Markup( 22 | vecty.MarkupIf(p.isSelected, vecty.Class("is-active")), 23 | prop.Href(p.hrefUrl), 24 | // event.Click(p.onClick).PreventDefault(), 25 | ), 26 | vecty.Text(p.title), 27 | ) 28 | } 29 | 30 | func (p *TabView) onClick(e *vecty.Event) { 31 | } 32 | -------------------------------------------------------------------------------- /internal/frontend/core/core.go: -------------------------------------------------------------------------------- 1 | package core 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | 9 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway" 10 | ) 11 | 12 | type GatewayApiClient struct { 13 | Policy gatewayv1.PolicyApi 14 | Backend gatewayv1.BackendApi 15 | Group gatewayv1.GroupApi 16 | Query gatewayv1.QueryApi 17 | } 18 | 19 | type Core struct { 20 | gatewayApiClient *GatewayApiClient 21 | } 22 | 23 | type ICore interface { 24 | GetQueries(count int, skip int, user string) ([]*gatewayv1.Query, error) 25 | } 26 | 27 | func NewCore(gatewayHost string) *Core { 28 | return &Core{ 29 | gatewayApiClient: &GatewayApiClient{ 30 | Backend: gatewayv1.NewBackendApiProtobufClient(gatewayHost, &http.Client{}), 31 | Group: gatewayv1.NewGroupApiProtobufClient(gatewayHost, &http.Client{}), 32 | Policy: gatewayv1.NewPolicyApiProtobufClient(gatewayHost, &http.Client{}), 33 | Query: gatewayv1.NewQueryApiProtobufClient(gatewayHost, &http.Client{}), 34 | }, 35 | } 36 | } 37 | 38 | func (c *Core) GetQueries(count int, skip int, user string) ([]*gatewayv1.Query, error) { 39 | req := 
gatewayv1.QueriesListRequest{ 40 | Skip: int32(skip), 41 | Count: int32(count), 42 | Username: user, 43 | } 44 | queriesResp, err := c.gatewayApiClient.Query.ListQueries(context.Background(), &req) 45 | if err != nil { 46 | println(err.Error()) 47 | return nil, errors.New(fmt.Sprint("Unable to Fetch list of queries", err.Error())) 48 | } 49 | 50 | return queriesResp.GetItems(), nil 51 | } 52 | -------------------------------------------------------------------------------- /internal/frontend/dispatcher/dispatcher.go: -------------------------------------------------------------------------------- 1 | package dispatcher 2 | 3 | // ID is a unique identifier representing a registered callback function. 4 | type ID int 5 | 6 | var ( 7 | idCounter ID 8 | callbacks = make(map[ID]func(action interface{})) 9 | ) 10 | 11 | // Dispatch dispatches the given action to all registered callbacks. 12 | func Dispatch(action interface{}) { 13 | for _, c := range callbacks { 14 | c(action) 15 | } 16 | } 17 | 18 | // Register registers the callback to handle dispatched actions, the returned 19 | // ID may be used to unregister the callback later. 20 | func Register(callback func(action interface{})) ID { 21 | idCounter++ 22 | id := idCounter 23 | callbacks[id] = callback 24 | return id 25 | } 26 | 27 | // Unregister unregisters the callback previously registered via a call to 28 | // Register. 
29 | func Unregister(id ID) { 30 | delete(callbacks, id) 31 | } 32 | -------------------------------------------------------------------------------- /internal/frontend/main.go: -------------------------------------------------------------------------------- 1 | // gopherjs doesnt build with anythign other than package `main` 2 | package main 3 | 4 | import ( 5 | "github.com/gopherjs/gopherjs/js" 6 | "github.com/hexops/vecty" 7 | "github.com/razorpay/trino-gateway/internal/frontend/components" 8 | "github.com/razorpay/trino-gateway/internal/frontend/core" 9 | ) 10 | 11 | func main() { 12 | path := accessURL() // fmt.Sprint("http://localhost:", "28000") 13 | c := core.NewCore(path) 14 | 15 | vecty.SetTitle("Trino-Gateway") 16 | // vecty.AddStylesheet("https://rawgit.com/tastejs/todomvc-common/master/base.css") 17 | // vecty.AddStylesheet("https://rawgit.com/tastejs/todomvc-app-css/master/index.css") 18 | vecty.AddStylesheet("https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css") 19 | 20 | vecty.RenderBody(components.GetNewPageViewComponent(c)) 21 | } 22 | 23 | var location = js.Global.Get("location") 24 | 25 | func accessURL() string { 26 | // current URL: http://localhost:8000/code/gopherjs/window-location/index.html?a=1 27 | 28 | // return - http://localhost:8000/code/gopherjs/window-location/index.html?a=1 29 | location.Get("href").String() 30 | // return - localhost:8000 31 | location.Get("host").String() 32 | // return - localhost 33 | location.Get("hostname").String() 34 | // return - /code/gopherjs/window-location/index.html 35 | location.Get("pathname").String() 36 | // return - http: 37 | location.Get("protocol").String() 38 | // return - http://localhost:8000 39 | location.Get("origin").String() 40 | // return - 8000 41 | location.Get("port").String() 42 | // return - ?a=1 43 | location.Get("search").String() 44 | 45 | return location.Get("origin").String() 46 | } 47 | -------------------------------------------------------------------------------- 
/internal/frontend/server/main.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/NYTimes/gziphandler" 8 | ) 9 | 10 | func NewServerHandler(ctx *context.Context) *http.Handler { 11 | guiFs := http.FileServer(http.Dir("./web/frontend")) 12 | appFrontendPath := "/" 13 | h := cacheHandler( 14 | compressionHandler( 15 | http.StripPrefix(appFrontendPath, guiFs), 16 | ), 17 | ) 18 | 19 | return &h 20 | } 21 | 22 | func cacheHandler(h http.Handler) http.Handler { 23 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 24 | w.Header().Set("Cache-Control", "max-age=180") 25 | h.ServeHTTP(w, r) 26 | }) 27 | } 28 | 29 | func compressionHandler(h http.Handler) http.Handler { 30 | // TODO: check https://github.com/CAFxX/httpcompression 31 | return gziphandler.GzipHandler(h) 32 | } 33 | -------------------------------------------------------------------------------- /internal/gatewayserver/backendApi/core.go: -------------------------------------------------------------------------------- 1 | package backendapi 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/fatih/structs" 7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models" 8 | "github.com/razorpay/trino-gateway/internal/gatewayserver/repo" 9 | ) 10 | 11 | type Core struct { 12 | backendRepo repo.IBackendRepo 13 | } 14 | 15 | type ICore interface { 16 | CreateOrUpdateBackend(ctx context.Context, params *BackendCreateParams) error 17 | GetBackend(ctx context.Context, id string) (*models.Backend, error) 18 | GetAllBackends(ctx context.Context) ([]models.Backend, error) 19 | GetAllActiveBackends(ctx context.Context) ([]models.Backend, error) 20 | UpdateBackend(ctx context.Context, b *models.Backend) error 21 | DeleteBackend(ctx context.Context, id string) error 22 | EnableBackend(ctx context.Context, id string) error 23 | DisableBackend(ctx context.Context, id string) error 
24 | MarkHealthyBackend(ctx context.Context, id string) error 25 | MarkUnhealthyBackend(ctx context.Context, id string) error 26 | } 27 | 28 | func NewCore(backend repo.IBackendRepo) *Core { 29 | return &Core{backendRepo: backend} 30 | } 31 | 32 | // CreateParams has attributes that are required for backend.Create() 33 | type BackendCreateParams struct { 34 | ID string 35 | Hostname string 36 | Scheme string 37 | ExternalUrl string 38 | IsEnabled bool 39 | IsHealthy bool 40 | UptimeSchedule string 41 | ClusterLoad int32 42 | ThresholdClusterLoad int32 43 | StatsUpdatedAt int64 44 | } 45 | 46 | func (c *Core) CreateOrUpdateBackend(ctx context.Context, params *BackendCreateParams) error { 47 | backend := models.Backend{ 48 | Hostname: params.Hostname, 49 | Scheme: params.Scheme, 50 | ExternalUrl: ¶ms.ExternalUrl, 51 | IsEnabled: ¶ms.IsEnabled, 52 | IsHealthy: ¶ms.IsHealthy, 53 | UptimeSchedule: ¶ms.UptimeSchedule, 54 | ClusterLoad: ¶ms.ClusterLoad, 55 | ThresholdClusterLoad: ¶ms.ThresholdClusterLoad, 56 | StatsUpdatedAt: ¶ms.StatsUpdatedAt, 57 | } 58 | backend.ID = params.ID 59 | 60 | _, exists := c.backendRepo.Find(ctx, params.ID) 61 | if exists == nil { // update 62 | return c.backendRepo.Update(ctx, &backend) 63 | } else { // create 64 | return c.backendRepo.Create(ctx, &backend) 65 | } 66 | } 67 | 68 | func (c *Core) GetBackend(ctx context.Context, id string) (*models.Backend, error) { 69 | backend, err := c.backendRepo.Find(ctx, id) 70 | return backend, err 71 | } 72 | 73 | func (c *Core) UpdateBackend(ctx context.Context, b *models.Backend) error { 74 | _, exists := c.backendRepo.Find(ctx, b.ID) 75 | if exists != nil { 76 | return exists 77 | } 78 | return c.backendRepo.Update(ctx, b) 79 | } 80 | 81 | func (c *Core) GetAllBackends(ctx context.Context) ([]models.Backend, error) { 82 | backends, err := c.backendRepo.FindMany(ctx, make(map[string]interface{})) 83 | return backends, err 84 | } 85 | 86 | type IFindManyParams interface { 87 | // GetCount() int32 88 | 
// GetSkip() int32 89 | // GetFrom() int32 90 | // GetTo() int32 91 | 92 | // custom 93 | GetIsEnabled() bool 94 | } 95 | 96 | type FindManyParams struct { 97 | // pagination 98 | // Count int32 99 | // Skip int32 100 | // From int32 101 | // To int32 102 | 103 | // custom 104 | IsEnabled bool `json:"is_enabled"` 105 | } 106 | 107 | func (p *FindManyParams) GetIsEnabled() bool { 108 | return p.IsEnabled 109 | } 110 | 111 | func (c *Core) FindMany(ctx context.Context, params IFindManyParams) ([]models.Backend, error) { 112 | conditionStr := structs.New(params) 113 | // use the json tag name, so we can respect omitempty tags 114 | conditionStr.TagName = "json" 115 | conditions := conditionStr.Map() 116 | 117 | return c.backendRepo.FindMany(ctx, conditions) 118 | } 119 | 120 | func (c *Core) GetAllActiveBackends(ctx context.Context) ([]models.Backend, error) { 121 | backends, err := c.FindMany(ctx, &FindManyParams{IsEnabled: true}) 122 | return backends, err 123 | } 124 | 125 | func (c *Core) DeleteBackend(ctx context.Context, id string) error { 126 | return c.backendRepo.Delete(ctx, id) 127 | } 128 | 129 | func (c *Core) EnableBackend(ctx context.Context, id string) error { 130 | return c.backendRepo.Enable(ctx, id) 131 | } 132 | 133 | func (c *Core) DisableBackend(ctx context.Context, id string) error { 134 | return c.backendRepo.Disable(ctx, id) 135 | } 136 | 137 | func (c *Core) MarkHealthyBackend(ctx context.Context, id string) error { 138 | return c.backendRepo.MarkHealthy(ctx, id) 139 | } 140 | 141 | func (c *Core) MarkUnhealthyBackend(ctx context.Context, id string) error { 142 | return c.backendRepo.MarkUnhealthy(ctx, id) 143 | } 144 | 145 | type EvaluateClientParams struct { 146 | ListeningPort int32 147 | } 148 | -------------------------------------------------------------------------------- /internal/gatewayserver/backendApi/server.go: -------------------------------------------------------------------------------- 1 | package backendapi 2 | 3 | import ( 
4 | "context" 5 | "errors" 6 | "fmt" 7 | "time" 8 | 9 | _ "github.com/twitchtv/twirp" 10 | 11 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models" 12 | "github.com/razorpay/trino-gateway/internal/provider" 13 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway" 14 | ) 15 | 16 | // Server has methods implementing of server rpc. 17 | type Server struct { 18 | core ICore 19 | } 20 | 21 | // NewServer returns a server. 22 | func NewServer(core ICore) *Server { 23 | return &Server{ 24 | core: core, 25 | } 26 | } 27 | 28 | // Create creates a new backend 29 | func (s *Server) CreateOrUpdateBackend(ctx context.Context, req *gatewayv1.Backend) (*gatewayv1.Empty, error) { 30 | // defer span.Finish() 31 | 32 | provider.Logger(ctx).Debugw("CreateOrUpdateBackend", map[string]interface{}{ 33 | "request": req.String(), 34 | }) 35 | 36 | createParams := BackendCreateParams{ 37 | ID: req.GetId(), 38 | Scheme: req.GetScheme().Enum().String(), 39 | Hostname: req.GetHostname(), 40 | ExternalUrl: req.GetExternalUrl(), 41 | IsEnabled: req.GetIsEnabled(), 42 | IsHealthy: req.GetIsHealthy(), 43 | UptimeSchedule: req.GetUptimeSchedule(), 44 | ClusterLoad: req.GetClusterLoad(), 45 | ThresholdClusterLoad: req.GetThresholdClusterLoad(), 46 | StatsUpdatedAt: req.GetStatsUpdatedAt(), 47 | } 48 | 49 | err := s.core.CreateOrUpdateBackend(ctx, &createParams) 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | return &gatewayv1.Empty{}, nil 55 | } 56 | 57 | // Get retrieves a single backend record 58 | func (s *Server) GetBackend(ctx context.Context, req *gatewayv1.BackendGetRequest) (*gatewayv1.BackendGetResponse, error) { 59 | provider.Logger(ctx).Debugw("GetBackend", map[string]interface{}{ 60 | "request": req.String(), 61 | }) 62 | 63 | backend, err := s.core.GetBackend(ctx, req.GetId()) 64 | if err != nil { 65 | return nil, err 66 | } 67 | backendProto, err := toBackendResponseProto(backend) 68 | if err != nil { 69 | return nil, err 70 | } 71 | return 
&gatewayv1.BackendGetResponse{Backend: backendProto}, nil 72 | } 73 | 74 | // List fetches a list of filtered backend records 75 | func (s *Server) ListAllBackends(ctx context.Context, req *gatewayv1.Empty) (*gatewayv1.BackendListAllResponse, error) { 76 | provider.Logger(ctx).Debugw("ListAllBackends", map[string]interface{}{ 77 | "request": req.String(), 78 | }) 79 | backends, err := s.core.GetAllBackends(ctx) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | backendsProto := make([]*gatewayv1.Backend, len(backends)) 85 | for i, backendModel := range backends { 86 | backend, err := toBackendResponseProto(&backendModel) 87 | if err != nil { 88 | return nil, err 89 | } 90 | backendsProto[i] = backend 91 | } 92 | 93 | response := gatewayv1.BackendListAllResponse{ 94 | Items: backendsProto, 95 | } 96 | 97 | return &response, nil 98 | } 99 | 100 | // Approve marks a backends status to approved 101 | 102 | func (s *Server) EnableBackend(ctx context.Context, req *gatewayv1.BackendEnableRequest) (*gatewayv1.Empty, error) { 103 | provider.Logger(ctx).Debugw("EnableBackend", map[string]interface{}{ 104 | "request": req.String(), 105 | }) 106 | err := s.core.EnableBackend(ctx, req.GetId()) 107 | if err != nil { 108 | return nil, err 109 | } 110 | 111 | return &gatewayv1.Empty{}, nil 112 | } 113 | 114 | func (s *Server) DisableBackend(ctx context.Context, req *gatewayv1.BackendDisableRequest) (*gatewayv1.Empty, error) { 115 | provider.Logger(ctx).Debugw("DisableBackend", map[string]interface{}{ 116 | "request": req.String(), 117 | }) 118 | err := s.core.DisableBackend(ctx, req.GetId()) 119 | if err != nil { 120 | return nil, err 121 | } 122 | 123 | return &gatewayv1.Empty{}, nil 124 | } 125 | 126 | func (s *Server) MarkHealthyBackend(ctx context.Context, req *gatewayv1.BackendMarkHealthyRequest) (*gatewayv1.Empty, error) { 127 | provider.Logger(ctx).Debugw("MarkHealthyBackend", map[string]interface{}{ 128 | "request": req.String(), 129 | }) 130 | err := 
s.core.MarkHealthyBackend(ctx, req.GetId()) 131 | if err != nil { 132 | return nil, err 133 | } 134 | 135 | return &gatewayv1.Empty{}, nil 136 | } 137 | 138 | func (s *Server) MarkUnhealthyBackend(ctx context.Context, req *gatewayv1.BackendMarkUnhealthyRequest) (*gatewayv1.Empty, error) { 139 | provider.Logger(ctx).Debugw("MarkUnhealthyBackend", map[string]interface{}{ 140 | "request": req.String(), 141 | }) 142 | err := s.core.MarkUnhealthyBackend(ctx, req.GetId()) 143 | if err != nil { 144 | return nil, err 145 | } 146 | 147 | return &gatewayv1.Empty{}, nil 148 | } 149 | 150 | func (s *Server) UpdateClusterLoadBackend( 151 | ctx context.Context, 152 | req *gatewayv1.BackendUpdateClusterLoadRequest, 153 | ) (*gatewayv1.Empty, error) { 154 | provider.Logger(ctx).Debugw("UpdateClusterLoadBackend", map[string]interface{}{ 155 | "request": req.String(), 156 | }) 157 | b, err := s.core.GetBackend(ctx, req.GetId()) 158 | if err != nil { 159 | return nil, err 160 | } 161 | *b.ClusterLoad = req.GetClusterLoad() 162 | *b.StatsUpdatedAt = time.Now().Unix() 163 | 164 | if err := s.core.UpdateBackend(ctx, b); err != nil { 165 | return nil, err 166 | } 167 | 168 | return &gatewayv1.Empty{}, nil 169 | } 170 | 171 | // Delete deletes a backend, soft-delete 172 | func (s *Server) DeleteBackend(ctx context.Context, req *gatewayv1.BackendDeleteRequest) (*gatewayv1.Empty, error) { 173 | provider.Logger(ctx).Debugw("DeleteBackend", map[string]interface{}{ 174 | "request": req.String(), 175 | }) 176 | err := s.core.DeleteBackend(ctx, req.GetId()) 177 | if err != nil { 178 | return nil, err 179 | } 180 | 181 | return &gatewayv1.Empty{}, nil 182 | } 183 | 184 | func toBackendResponseProto(backend *models.Backend) (*gatewayv1.Backend, error) { 185 | if backend == nil { 186 | return &gatewayv1.Backend{}, nil 187 | } 188 | scheme, ok := gatewayv1.Backend_Scheme_value[backend.Scheme] 189 | if !ok { 190 | return nil, errors.New(fmt.Sprint("error encoding response: invalid scheme ", 
backend.Scheme)) 191 | } 192 | response := gatewayv1.Backend{ 193 | Id: backend.ID, 194 | Hostname: backend.Hostname, 195 | Scheme: *gatewayv1.Backend_Scheme(scheme).Enum(), 196 | ExternalUrl: *backend.ExternalUrl, 197 | IsEnabled: *backend.IsEnabled, 198 | UptimeSchedule: *backend.UptimeSchedule, 199 | ClusterLoad: *backend.ClusterLoad, 200 | ThresholdClusterLoad: *backend.ThresholdClusterLoad, 201 | StatsUpdatedAt: *backend.StatsUpdatedAt, 202 | IsHealthy: *backend.IsHealthy, 203 | } 204 | 205 | return &response, nil 206 | } 207 | -------------------------------------------------------------------------------- /internal/gatewayserver/backendApi/validation.go: -------------------------------------------------------------------------------- 1 | package backendapi 2 | 3 | // import ( 4 | // validation "github.com/go-ozzo/ozzo-validation/v4" 5 | // ) 6 | 7 | // func (cp *CreateParams) Validate() error { 8 | // err := validation.ValidateStruct(cp, 9 | // // id, required, length non zero 10 | // validation.Field(&cp.ID, validation.Required, validation.RuneLength(1, 50)), 11 | 12 | // // Hostname, required, string, length 1-50 13 | // validation.Field(&cp.Hostname, validation.Required, validation.RuneLength(1, 50)), 14 | 15 | // // Scheme, required, string, Union(http, https) 16 | // validation.Field(&cp.Scheme, validation.Required, validation.In("http", "https")), 17 | 18 | // // // last_name, required, string, length 1-30 19 | // // validation.Field(&cp.LastName, validation.Required, validation.RuneLength(1, 30)), 20 | // ) 21 | 22 | // return err 23 | // // if err == nil { 24 | // // return nil 25 | // // } 26 | 27 | // // publicErr := errorclass.ErrValidationFailure.New(""). 28 | // // Wrap(err). 
29 | // // WithPublic(&errors.Public{
30 | // // Description: err.Error(),
31 | // // })
32 |
33 | // // return publicErr
34 | // }
35 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/dbRepo/dbRepo.go:
--------------------------------------------------------------------------------
1 | package dbRepo
2 |
3 | import (
4 | 	"context"
5 |
6 | 	"github.com/razorpay/trino-gateway/pkg/spine"
7 | 	"github.com/razorpay/trino-gateway/pkg/spine/db"
8 | )
9 | // DbRepo embeds spine.Repo; NewDbRepo returns it as an IDbRepo, so the interface methods are presumably promoted from spine.Repo.
10 | type DbRepo struct {
11 | 	spine.Repo
12 | }
13 | // IDbRepo is the generic persistence interface returned by NewDbRepo.
14 | type IDbRepo interface {
15 | 	Create(ctx context.Context, receiver spine.IModel) error
16 | 	FindByID(ctx context.Context, receiver spine.IModel, id string) error
17 | 	FindWithConditionByIDs(ctx context.Context, receivers interface{}, condition map[string]interface{}, ids []string) error
18 | 	FindMany(ctx context.Context, receivers interface{}, condition map[string]interface{}) error
19 | 	Delete(ctx context.Context, receiver spine.IModel) error
20 | 	Update(ctx context.Context, receiver spine.IModel, attrList ...string) error
21 | 	Preload(ctx context.Context, query string, args ...interface{}) *spine.Repo
22 | 	ClearAssociations(ctx context.Context, receiver spine.IModel, name string) error
23 | 	ReplaceAssociations(ctx context.Context, receiver spine.IModel, name string, ass interface{}) error
24 | }
25 | // NewDbRepo wraps db in a DbRepo. Inside the body the parameter db shadows the db package.
26 | func NewDbRepo(db *db.DB) IDbRepo {
27 | 	return &DbRepo{
28 | 		Repo: spine.Repo{
29 | 			Db: db,
30 | 		},
31 | 	}
32 | }
33 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20210805195304_bootstrap.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | 	"database/sql"
5 |
6 | 	"github.com/pressly/goose/v3"
7 | )
8 | // init registers this migration's Up/Down functions with goose.
9 | func init() {
10 | 	goose.AddMigration(Up20210805195304, Down20210805195304)
11 | }
12 | // Up20210805195304 creates the initial schema (backends, groups_, group_backends_mappings, policies, queries) and then the foreign keys between them.
13 | func Up20210805195304(tx *sql.Tx) error {
14 | 	var err error
15 | 	_, err = tx.Exec(`
16 | CREATE TABLE backends (
17 | id VARCHAR(255) NOT NULL,
18 | hostname VARCHAR(255) NOT NULL,
19 | scheme ENUM('http', 'https') DEFAULT 'http',
20 | external_url VARCHAR(255) NOT NULL,
21 | is_enabled bool DEFAULT FALSE,
22 | uptime_schedule VARCHAR(255) DEFAULT '* * * * *',
23 | cluster_load INT DEFAULT 0,
24 | threshold_cluster_load INT DEFAULT 0,
25 | stats_updated_at INT(11) NULL,
26 | created_at INT(11) NOT NULL,
27 | updated_at INT(11) NOT NULL,
28 | PRIMARY KEY (id),
29 | KEY users_created_at_index (created_at),
30 | KEY users_updated_at_index (updated_at)
31 | );`)
32 | 	if err != nil {
33 | 		return err
34 | 	}
35 | 	// NOTE(review): every table reuses the users_created_at_index / users_updated_at_index names; MySQL scopes index names per table, so this is legal. Do not rename here — an applied migration should stay unchanged.
36 | 	// groups is a keyword in mysql, so we use groups_
37 | 	_, err = tx.Exec("CREATE TABLE `groups_` (" +
38 | 		`id VARCHAR(255) NOT NULL,
39 | strategy ENUM('random', 'round_robin', 'least_load') DEFAULT 'random',
40 | is_enabled bool DEFAULT FALSE,
41 | last_routed_backend VARCHAR(255),
42 | created_at INT(11) NOT NULL,
43 | updated_at INT(11) NOT NULL,
44 | PRIMARY KEY (id),
45 | KEY users_created_at_index (created_at),
46 | KEY users_updated_at_index (updated_at)
47 | );`)
48 | 	if err != nil {
49 | 		return err
50 | 	}
51 |
52 | 	_, err = tx.Exec(`CREATE TABLE group_backends_mappings (
53 | id int AUTO_INCREMENT,
54 | group_id varchar(255),
55 | backend_id varchar(255),
56 | created_at int(11),
57 | updated_at int(11),
58 | PRIMARY KEY (id),
59 | UNIQUE KEY (group_id, backend_id),
60 | KEY users_created_at_index (created_at),
61 | KEY users_updated_at_index (updated_at)
62 | );`)
63 | 	if err != nil {
64 | 		return err
65 | 	}
66 |
67 | 	_, err = tx.Exec(`CREATE TABLE policies (
68 | id varchar(255),
69 | rule_type ENUM ('header_client_tags', 'header_connection_properties', 'header_client_host', 'listening_port'),
70 | rule_value varchar(255),
71 | group_id varchar(255),
72 | fallback_group_id varchar(255),
73 | is_enabled bool,
74 | created_at int(11),
75 | updated_at int(11),
76 | PRIMARY KEY (id),
77 | KEY users_created_at_index (created_at),
78 | KEY users_updated_at_index (updated_at)
79 | );`)
80 | 	if err != nil {
81 | 		return err
82 | 	}
83 |
84 | 	_, err = tx.Exec(`CREATE TABLE queries (
85 | id varchar(255),
86 | text varchar(255),
87 | client_ip varchar(255),
88 | group_id varchar(255) NULL,
89 | backend_id varchar(255) NULL,
90 | username varchar(255),
91 | server_host varchar(255),
92 | submitted_at int(11),
93 | created_at int(11),
94 | updated_at int(11),
95 | PRIMARY KEY (id),
96 | KEY users_created_at_index (created_at),
97 | KEY users_updated_at_index (updated_at)
98 | );`)
99 | 	if err != nil {
100 | 		return err
101 | 	}
102 | 	// With every table created, add the foreign keys between them.
103 | 	_, err = tx.Exec(`ALTER TABLE group_backends_mappings ADD FOREIGN KEY (group_id) REFERENCES groups_ (id);`)
104 | 	if err != nil {
105 | 		return err
106 | 	}
107 |
108 | 	_, err = tx.Exec(`ALTER TABLE group_backends_mappings ADD FOREIGN KEY (backend_id) REFERENCES backends (id) ON DELETE CASCADE ON UPDATE CASCADE;`)
109 | 	if err != nil {
110 | 		return err
111 | 	}
112 |
113 | 	_, err = tx.Exec(`ALTER TABLE policies ADD FOREIGN KEY (group_id) REFERENCES groups_ (id);`)
114 | 	if err != nil {
115 | 		return err
116 | 	}
117 |
118 | 	_, err = tx.Exec(`ALTER TABLE policies ADD FOREIGN KEY (fallback_group_id) REFERENCES groups_ (id);`)
119 | 	if err != nil {
120 | 		return err
121 | 	}
122 |
123 | 	_, err = tx.Exec(`ALTER TABLE queries ADD FOREIGN KEY (group_id) REFERENCES groups_ (id);`)
124 | 	if err != nil {
125 | 		return err
126 | 	}
127 |
128 | 	_, err = tx.Exec(`ALTER TABLE queries ADD FOREIGN KEY (backend_id) REFERENCES backends (id);`)
129 | 	if err != nil {
130 | 		return err
131 | 	}
132 | 	return err
133 | }
134 | // Down20210805195304 drops the tables created by Up20210805195304, dropping referencing tables before the tables they reference.
135 | func Down20210805195304(tx *sql.Tx) error {
136 | 	var err error
137 | 	_, err = tx.Exec(`DROP TABLE IF EXISTS queries;`)
138 | 	if err != nil {
139 | 		return err
140 | 	}
141 | 	_, err = tx.Exec(`DROP TABLE IF EXISTS policies;`)
142 | 	if err != nil {
143 | 		return err
144 | 	}
145 | 	_, err = tx.Exec(`DROP TABLE IF EXISTS group_backends_mappings;`)
146 | 	if err != nil {
147 | 		return err
148 | 	}
149 | 	_, err = tx.Exec(`DROP TABLE IF EXISTS backends;`)
150 | 	if err != nil {
151 | 		return err
152 | 	}
153 | 	_, err = tx.Exec(`DROP TABLE IF EXISTS groups_;`)
154 | 	if err != nil {
155 | 		return err
156 | 	}
157 | 	return err
158 | }
159 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20211203205304_alter_backends.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | 	"database/sql"
5 |
6 | 	"github.com/pressly/goose/v3"
7 | )
8 | // init registers this migration's Up/Down functions with goose.
9 | func init() {
10 | 	goose.AddMigration(Up20211203205304, Down20211203205304)
11 | }
12 | // Up20211203205304 adds the is_healthy BOOL column (default false) to backends.
13 | func Up20211203205304(tx *sql.Tx) error {
14 | 	var err error
15 |
16 | 	_, err = tx.Exec("ALTER TABLE `backends` ADD COLUMN `is_healthy` BOOL DEFAULT false;")
17 | 	if err != nil {
18 | 		return err
19 | 	}
20 | 	return err
21 | }
22 | // Down20211203205304 drops backends.is_healthy.
23 | func Down20211203205304(tx *sql.Tx) error {
24 | 	var err error
25 |
26 | 	_, err = tx.Exec("ALTER TABLE `backends` DROP COLUMN `is_healthy`;")
27 | 	if err != nil {
28 | 		return err
29 | 	}
30 | 	return err
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20220107205304_increase_query_text_size.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | 	"database/sql"
5 |
6 | 	"github.com/pressly/goose/v3"
7 | )
8 | // init registers this migration's Up/Down functions with goose.
9 | func init() {
10 | 	goose.AddMigration(Up20220107205304, Down20220107205304)
11 | }
12 | // Up20220107205304 widens queries.text from VARCHAR(255) to VARCHAR(500).
13 | func Up20220107205304(tx *sql.Tx) error {
14 | 	var err error
15 |
16 | 	_, err = tx.Exec("ALTER TABLE `queries` MODIFY COLUMN `text` VARCHAR(500);")
17 | 	if err != nil {
18 | 		return err
19 | 	}
20 | 	return err
21 | }
22 | // Down20220107205304 narrows queries.text back to VARCHAR(255). NOTE(review): rows longer than 255 chars may make this fail or be truncated, depending on the MySQL sql_mode.
23 | func Down20220107205304(tx *sql.Tx) error {
24 | 	var err error
25 |
26 | 	_, err = tx.Exec("ALTER TABLE `queries` MODIFY COLUMN `text` VARCHAR(255);")
27 | 	if err != nil {
28 | 		return err
29 | 	}
30 | 	return err
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20240524205304_add_auth_delegation.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | 	"database/sql"
5 |
6 | 	"github.com/pressly/goose/v3"
7 | )
8 | // init registers this migration's Up/Down functions with goose.
9 | func init() {
10 | 	goose.AddMigration(Up20240524205304, Down20240524205304)
11 | }
12 | // Up20240524205304 adds the is_auth_delegated BOOL column (default false) to policies.
13 | func Up20240524205304(tx *sql.Tx) error {
14 | 	var err error
15 |
16 | 	_, err = tx.Exec("ALTER TABLE `policies` ADD COLUMN `is_auth_delegated` BOOL DEFAULT false;")
17 | 	if err != nil {
18 | 		return err
19 | 	}
20 | 	return err
21 | }
22 | // Down20240524205304 drops policies.is_auth_delegated.
23 | func Down20240524205304(tx *sql.Tx) error {
24 | 	var err error
25 |
26 | 	_, err = tx.Exec("ALTER TABLE `policies` DROP COLUMN `is_auth_delegated`;")
27 | 	if err != nil {
28 | 		return err
29 | 	}
30 | 	return err
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/database/migrations/20240525205304_add_set_source.go:
--------------------------------------------------------------------------------
1 | package migration
2 |
3 | import (
4 | 	"database/sql"
5 |
6 | 	"github.com/pressly/goose/v3"
7 | )
8 | // init registers this migration's Up/Down functions with goose.
9 | func init() {
10 | 	goose.AddMigration(Up20240525205304, Down20240525205304)
11 | }
12 | // Up20240525205304 adds the set_request_source VARCHAR(255) column (default '') to policies.
13 | func Up20240525205304(tx *sql.Tx) error {
14 | 	var err error
15 |
16 | 	_, err = tx.Exec("ALTER TABLE `policies` ADD COLUMN `set_request_source` VARCHAR(255) DEFAULT '';")
17 | 	if err != nil {
18 | 		return err
19 | 	}
20 | 	return err
21 | }
22 | // Down20240525205304 drops policies.set_request_source.
23 | func Down20240525205304(tx *sql.Tx) error {
24 | 	var err error
25 |
26 | 	_, err = tx.Exec("ALTER TABLE `policies` DROP COLUMN `set_request_source`;")
27 | 	if err != nil {
28 | 		return err
29 | 	}
30 | 	return err
31 | }
32 |
--------------------------------------------------------------------------------
/internal/gatewayserver/groupApi/server.go:
--------------------------------------------------------------------------------
1 | package groupapi
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"strings"
7 |
8 | 	"github.com/razorpay/trino-gateway/internal/gatewayserver/models"
9 | 	"github.com/razorpay/trino-gateway/internal/provider"
10 | 	gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway"
11 | 	_ "github.com/twitchtv/twirp"
12 | )
13 |
14 | // Server implements the routing-group management RPCs by delegating to an ICore.
15 | type Server struct {
16 | 	core ICore
17 | }
18 |
19 | // NewServer returns a Server backed by core.
20 | func NewServer(core ICore) *Server {
21 | 	return &Server{
22 | 		core: core,
23 | 	}
24 | }
25 |
26 | // CreateOrUpdateGroup converts req into GroupCreateParams and forwards it to
27 | // ICore.CreateOrUpdateGroup.
28 | func (s *Server) CreateOrUpdateGroup(ctx context.Context, req *gatewayv1.Group) (*gatewayv1.Empty, error) {
29 | 	provider.Logger(ctx).Debugw("CreateOrUpdateGroup", map[string]interface{}{
30 | 		"request": req.String(),
31 | 	})
32 |
33 | 	createParams := GroupCreateParams{
34 | 		ID:                req.GetId(),
35 | 		Strategy:          req.GetStrategy().String(),
36 | 		Backends:          req.GetBackends(),
37 | 		IsEnabled:         req.GetIsEnabled(),
38 | 		LastRoutedBackend: req.GetLastRoutedBackend(),
39 | 	}
40 |
41 | 	if err := s.core.CreateOrUpdateGroup(ctx, &createParams); err != nil {
42 | 		return nil, err
43 | 	}
44 |
45 | 	return &gatewayv1.Empty{}, nil
46 | }
47 |
48 | // GetGroup returns the group with the requested id.
49 | func (s *Server) GetGroup(ctx context.Context, req *gatewayv1.GroupGetRequest) (*gatewayv1.GroupGetResponse, error) {
50 | 	provider.Logger(ctx).Debugw("GetGroup", map[string]interface{}{
51 | 		"request": req.String(),
52 | 	})
53 |
54 | 	group, err := s.core.GetGroup(ctx, req.GetId())
55 | 	if err != nil {
56 | 		return nil, err
57 | 	}
58 | 	groupProto, err := toGroupResponseProto(group)
59 | 	if err != nil {
60 | 		return nil, err
61 | 	}
62 | 	return &gatewayv1.GroupGetResponse{Group: groupProto}, nil
63 | }
64 |
65 | // ListAllGroups returns every group reported by ICore.GetAllGroups.
66 | func (s *Server) ListAllGroups(ctx context.Context, req *gatewayv1.Empty) (*gatewayv1.GroupListAllResponse, error) {
67 | 	provider.Logger(ctx).Debugw("ListAllGroups", map[string]interface{}{
68 | 		"request": req.String(),
69 | 	})
70 | 	groups, err := s.core.GetAllGroups(ctx)
71 | 	if err != nil {
72 | 		return nil, err
73 | 	}
74 |
75 | 	groupsProto := make([]*gatewayv1.Group, len(groups))
76 | 	for i := range groups {
77 | 		group, err := toGroupResponseProto(&groups[i])
78 | 		if err != nil {
79 | 			return nil, err
80 | 		}
81 | 		groupsProto[i] = group
82 | 	}
83 |
84 | 	return &gatewayv1.GroupListAllResponse{Items: groupsProto}, nil
85 | }
86 |
87 | // EnableGroup enables the group with the requested id.
88 | func (s *Server) EnableGroup(ctx context.Context, req *gatewayv1.GroupEnableRequest) (*gatewayv1.Empty, error) {
89 | 	provider.Logger(ctx).Debugw("EnableGroup", map[string]interface{}{
90 | 		"request": req.String(),
91 | 	})
92 | 	if err := s.core.EnableGroup(ctx, req.GetId()); err != nil {
93 | 		return nil, err
94 | 	}
95 |
96 | 	return &gatewayv1.Empty{}, nil
97 | }
98 |
99 | // DisableGroup disables the group with the requested id.
100 | func (s *Server) DisableGroup(ctx context.Context, req *gatewayv1.GroupDisableRequest) (*gatewayv1.Empty, error) {
101 | 	provider.Logger(ctx).Debugw("DisableGroup", map[string]interface{}{
102 | 		"request": req.String(),
103 | 	})
104 | 	if err := s.core.DisableGroup(ctx, req.GetId()); err != nil {
105 | 		return nil, err
106 | 	}
107 |
108 | 	return &gatewayv1.Empty{}, nil
109 | }
110 |
111 | // DeleteGroup deletes the group with the requested id. NOTE(review): previously
112 | // documented as a soft-delete; that depends on the core/repository — confirm.
113 | func (s *Server) DeleteGroup(ctx context.Context, req *gatewayv1.GroupDeleteRequest) (*gatewayv1.Empty, error) {
114 | 	provider.Logger(ctx).Debugw("DeleteGroup", map[string]interface{}{
115 | 		"request": req.String(),
116 | 	})
117 | 	if err := s.core.DeleteGroup(ctx, req.GetId()); err != nil {
118 | 		return nil, err
119 | 	}
120 |
121 | 	return &gatewayv1.Empty{}, nil
122 | }
123 |
124 | // toGroupResponseProto converts a group model into its RPC representation. A nil
125 | // model yields an empty proto. Nil pointer fields (e.g. a NULL
126 | // last_routed_backend column) are left at the proto zero value rather than being
127 | // dereferenced; a missing or unknown strategy is reported as an encoding error.
128 | func toGroupResponseProto(group *models.Group) (*gatewayv1.Group, error) {
129 | 	if group == nil {
130 | 		return &gatewayv1.Group{}, nil
131 | 	}
132 | 	var strategyName string
133 | 	if group.Strategy != nil {
134 | 		strategyName = *group.Strategy
135 | 	}
136 | 	strategy, ok := gatewayv1.Group_RoutingStrategy_value[strings.ToUpper(strategyName)]
137 | 	if !ok {
138 | 		return nil, fmt.Errorf("error encoding response: invalid strategy %s", strategyName)
139 | 	}
140 | 	backends := make([]string, 0, len(group.GroupBackendsMappings))
141 | 	for _, mapping := range group.GroupBackendsMappings {
142 | 		backends = append(backends, mapping.BackendId)
143 | 	}
144 | 	response := gatewayv1.Group{
145 | 		Id:       group.ID,
146 | 		Strategy: gatewayv1.Group_RoutingStrategy(strategy),
147 | 		Backends: backends,
148 | 	}
149 | 	if group.IsEnabled != nil {
150 | 		response.IsEnabled = *group.IsEnabled
151 | 	}
152 | 	if group.LastRoutedBackend != nil {
153 | 		response.LastRoutedBackend = *group.LastRoutedBackend
154 | 	}
155 |
156 | 	return &response, nil
157 | }
158 |
159 | // EvaluateBackendForGroups forwards the requested group ids to
160 | // ICore.EvaluateBackendForGroups and returns the backend id and group id it
161 | // reports.
162 | func (s *Server) EvaluateBackendForGroups(ctx context.Context, req *gatewayv1.EvaluateBackendRequest) (*gatewayv1.EvaluateBackendResponse, error) {
163 | 	provider.Logger(ctx).Debugw("EvaluateBackendForGroups", map[string]interface{}{
164 | 		"request": req.String(),
165 | 	})
166 |
167 | 	backendID, groupID, err := s.core.EvaluateBackendForGroups(ctx, req.GetGroupIds())
168 | 	if err != nil {
169 | 		return nil, err
170 | 	}
171 | 	return &gatewayv1.EvaluateBackendResponse{BackendId: backendID, GroupId: groupID}, nil
172 | }
173 |
--------------------------------------------------------------------------------
/internal/gatewayserver/groupApi/validation.go:
--------------------------------------------------------------------------------
1 | package groupapi
2 |
3 | // import (
4 | // validation "github.com/go-ozzo/ozzo-validation/v4"
5 | // )
6 |
7 | // func (cp *CreateParams) Validate() error {
8 | // err := validation.ValidateStruct(cp,
9 | // // id, required, length non zero
10 | // validation.Field(&cp.ID, validation.Required, validation.RuneLength(1, 50)),
11 |
12 |
// // Hostname, required, string, length 1-50 13 | // validation.Field(&cp.Hostname, validation.Required, validation.RuneLength(1, 50)), 14 | 15 | // // Scheme, required, string, Union(http, https) 16 | // validation.Field(&cp.Scheme, validation.Required, validation.In("http", "https")), 17 | 18 | // // // last_name, required, string, length 1-30 19 | // // validation.Field(&cp.LastName, validation.Required, validation.RuneLength(1, 30)), 20 | // ) 21 | 22 | // return err 23 | // // if err == nil { 24 | // // return nil 25 | // // } 26 | 27 | // // publicErr := errorclass.ErrValidationFailure.New(""). 28 | // // Wrap(err). 29 | // // WithPublic(&errors.Public{ 30 | // // Description: err.Error(), 31 | // // }) 32 | 33 | // // return publicErr 34 | // } 35 | -------------------------------------------------------------------------------- /internal/gatewayserver/healthApi/core.go: -------------------------------------------------------------------------------- 1 | package healthapi 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | ) 8 | 9 | // Core holds business logic and/or orchestrator of other things in the package. 10 | type Core struct { 11 | isHealthy bool 12 | mutex sync.Mutex 13 | } 14 | 15 | // NewCore creates Core. 16 | func NewCore() *Core { 17 | return &Core{ 18 | isHealthy: true, 19 | } 20 | } 21 | 22 | // RunHealthCheck runs various server checks and returns true if all individual components are working fine. 23 | // Todo: Fix server check response per https://tools.ietf.org/id/draft-inadarei-api-health-check-01.html :) 24 | func (c *Core) RunHealthCheck(ctx context.Context) (bool, error) { 25 | if !c.isHealthy { 26 | return false, fmt.Errorf("server marked unhealthy") 27 | } 28 | 29 | var err error 30 | 31 | // Checks the DB connection exists and is alive by executing a select query. 
32 | // err = c.repo.Alive(ctx) 33 | if err != nil { 34 | // logger.Ctx(ctx).Errorw("failed to execute select query on db connection", "error", err) 35 | 36 | // Check fallback routing group exists & has atleast 1 active backend - IF NO set state to INVALID 37 | // 38 | } 39 | isDbAlive := err == nil 40 | 41 | return isDbAlive, err 42 | } 43 | 44 | // MarkUnhealthy marks the server as unhealthy for health check to return negative 45 | func (c *Core) MarkUnhealthy() { 46 | c.mutex.Lock() 47 | c.isHealthy = false 48 | c.mutex.Unlock() 49 | } 50 | -------------------------------------------------------------------------------- /internal/gatewayserver/healthApi/server.go: -------------------------------------------------------------------------------- 1 | package healthapi 2 | 3 | import ( 4 | "context" 5 | 6 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway" 7 | "github.com/twitchtv/twirp" 8 | ) 9 | 10 | // Server has methods implementing of server rpc. 11 | type Server struct { 12 | core *Core 13 | } 14 | 15 | // NewServer returns a server. 16 | func NewServer(core *Core) *Server { 17 | return &Server{ 18 | core: core, 19 | } 20 | } 21 | 22 | // Check returns service's serving status. 
23 | func (s *Server) Check(ctx context.Context, req *gatewayv1.HealthCheckRequest) (*gatewayv1.HealthCheckResponse, error) {
24 | 	ok, err := s.core.RunHealthCheck(ctx)
25 | 	if !ok {
26 | 		// Guard against a failed check that carries no error, so err.Error()
27 | 		// is never called on a nil error.
28 | 		msg := "health check failed"
29 | 		if err != nil {
30 | 			msg = err.Error()
31 | 		}
32 | 		return &gatewayv1.HealthCheckResponse{
33 | 			ServingStatus: gatewayv1.HealthCheckResponse_SERVING_STATUS_NOT_SERVING,
34 | 		}, twirp.NewError(twirp.Unavailable, msg)
35 | 	}
36 | 	return &gatewayv1.HealthCheckResponse{
37 | 		ServingStatus: gatewayv1.HealthCheckResponse_SERVING_STATUS_SERVING,
38 | 	}, nil
39 | }
40 |
--------------------------------------------------------------------------------
/internal/gatewayserver/hooks/auth.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | 	"context"
5 | 	"crypto/subtle"
6 | 	"fmt"
7 | 	"net/http"
8 | 	"strings"
9 |
10 | 	"github.com/twitchtv/twirp"
11 |
12 | 	"github.com/razorpay/trino-gateway/internal/boot"
13 | )
14 |
15 | type contextkey int
16 |
17 | const (
18 | 	authTokenCtxKey contextkey = iota
19 | 	authUrlPathCtxKey
20 | )
21 |
22 | // Auth returns twirp hooks that require a valid API token on every request
23 | // except read-only ones (twirp methods named Get* or List*). The token and the
24 | // URL path are read from the context populated by WithAuth.
25 | func Auth() *twirp.ServerHooks {
26 | 	hooks := &twirp.ServerHooks{}
27 |
28 | 	hooks.RequestReceived = func(ctx context.Context) (context.Context, error) {
29 | 		urlPath, _ := ctx.Value(authUrlPathCtxKey).(string)
30 | 		if isReadOnlyMethod(urlPath) {
31 | 			return ctx, nil
32 | 		}
33 |
34 | 		token, _ := ctx.Value(authTokenCtxKey).(string)
35 | 		if token == "" {
36 | 			return ctx, twirp.NewError(
37 | 				twirp.Unauthenticated,
38 | 				fmt.Sprint(
39 | 					"empty/undefined apiToken in header: ",
40 | 					boot.Config.Auth.TokenHeaderKey),
41 | 			)
42 | 		}
43 |
44 | 		// Compare in constant time so response latency does not reveal how much
45 | 		// of the configured token a guess matched.
46 | 		if subtle.ConstantTimeCompare([]byte(token), []byte(boot.Config.Auth.Token)) == 1 {
47 | 			return ctx, nil
48 | 		}
49 |
50 | 		return ctx, twirp.NewError(twirp.Unauthenticated, "invalid apiToken for authentication")
51 | 	}
52 |
53 | 	return hooks
54 | }
55 |
56 | // isReadOnlyMethod reports whether the twirp method named by the last path
57 | // segment of urlPath (".../<package>.<Service>/<Method>") starts with "Get" or
58 | // "List". Matching only that segment avoids exempting a path that merely
59 | // contains "/Get" or "/List" somewhere else.
60 | func isReadOnlyMethod(urlPath string) bool {
61 | 	method := urlPath[strings.LastIndexByte(urlPath, '/')+1:]
62 | 	return strings.HasPrefix(method, "Get") || strings.HasPrefix(method, "List")
63 | }
64 |
65 | // WithAuth wraps h so that the API token header (named by
66 | // boot.Config.Auth.TokenHeaderKey) and the URL path of each request are stored
67 | // in the request context for the Auth hooks.
68 | func WithAuth(h http.Handler) http.Handler {
69 | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
70 | 		ctx := r.Context()
71 | 		ctx = context.WithValue(ctx, authTokenCtxKey, r.Header.Get(boot.Config.Auth.TokenHeaderKey))
72 | 		ctx = context.WithValue(ctx, authUrlPathCtxKey, r.URL.Path)
73 |
74 | 		h.ServeHTTP(w, r.WithContext(ctx))
75 | 	})
76 | }
77 |
--------------------------------------------------------------------------------
/internal/gatewayserver/hooks/ctx.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | 	"context"
5 |
6 | 	"github.com/razorpay/trino-gateway/pkg/logger"
7 | 	"github.com/twitchtv/twirp"
8 |
9 | 	"github.com/razorpay/trino-gateway/internal/boot"
10 | 	"github.com/razorpay/trino-gateway/internal/provider"
11 | )
12 |
13 | // Ctx returns twirp hooks that, once a request is routed, call
14 | // boot.WithRequestID(ctx, "") (presumably generating an id when none is given —
15 | // confirm) and store a logger enriched with request metadata in the context
16 | // under logger.LoggerCtxKey.
17 | func Ctx() *twirp.ServerHooks {
18 | 	hooks := &twirp.ServerHooks{}
19 |
20 | 	hooks.RequestRouted = func(ctx context.Context) (context.Context, error) {
21 | 		ctx = boot.WithRequestID(ctx, "")
22 |
23 | 		// Todo: Check why method, service and package names are not known in this hook.
24 | 		reqMethod, _ := twirp.MethodName(ctx)
25 | 		reqService, _ := twirp.ServiceName(ctx)
26 | 		reqPackage, _ := twirp.PackageName(ctx)
27 | 		fields := map[string]interface{}{
28 | 			"reqId": boot.GetRequestID(ctx),
29 | 			// TODO: set auth user in auth Hook
30 | 			// "reqUser": ctx.Value("authUserCtxKey"),
31 | 			"reqMethod":  reqMethod,
32 | 			"reqService": reqService,
33 | 			"reqPackage": reqPackage,
34 | 		}
35 |
36 | 		return context.WithValue(ctx, logger.LoggerCtxKey, provider.Logger(ctx).WithFields(fields)), nil
37 | 	}
38 |
39 | 	return hooks
40 | }
41 |
--------------------------------------------------------------------------------
/internal/gatewayserver/hooks/metric.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"time"
7 |
8 | 	"github.com/razorpay/trino-gateway/internal/gatewayserver/metrics"
9 | 	"github.com/twitchtv/twirp"
10 | )
11 |
12 | var reqStartTsCtxKey = new(int)
13 |
14 | // Metric returns twirp hooks that count received requests and sent responses
15 | // and record response durations (in milliseconds) labelled by package,
16 | // service, method and status code.
17 | func Metric() *twirp.ServerHooks {
18 | 	hooks := &twirp.ServerHooks{}
19 |
20 | 	// RequestReceived: remember when the request arrived.
21 | 	hooks.RequestReceived = func(ctx context.Context) (context.Context, error) {
22 | 		return markRequestStart(ctx), nil
23 | 	}
24 |
25 | 	// RequestRouted: count the request once its method is known.
26 | 	hooks.RequestRouted = func(ctx context.Context) (context.Context, error) {
27 | 		pkg, _ := twirp.PackageName(ctx)
28 | 		service, _ := twirp.ServiceName(ctx)
29 | 		method, _ := twirp.MethodName(ctx)
30 |
31 | 		metrics.RequestsReceivedTotal.
32 | 			WithLabelValues(pkg, service, method).
33 | 			Inc()
34 |
35 | 		return ctx, nil
36 | 	}
37 |
38 | 	// ResponseSent: count the response and observe its latency.
39 | 	hooks.ResponseSent = func(ctx context.Context) {
40 | 		pkg, _ := twirp.PackageName(ctx)
41 | 		service, _ := twirp.ServiceName(ctx)
42 | 		method, _ := twirp.MethodName(ctx)
43 | 		statusCode, _ := twirp.StatusCode(ctx)
44 | 		status := fmt.Sprintf("%v", statusCode)
45 |
46 | 		metrics.ResponsesSentTotal.WithLabelValues(pkg, service, method, status).Inc()
47 |
48 | 		// Without a start timestamp (presumably when an earlier chained
49 | 		// RequestReceived hook rejected the request before markRequestStart
50 | 		// ran) the duration would be measured from the zero time; skip it.
51 | 		start, ok := getRequestStart(ctx)
52 | 		if !ok {
53 | 			return
54 | 		}
55 | 		duration := float64(time.Since(start).Milliseconds())
56 | 		metrics.ResponseDurations.WithLabelValues(pkg, service, method, status).Observe(duration)
57 | 	}
58 |
59 | 	return hooks
60 | }
61 |
62 | // markRequestStart returns a copy of ctx carrying the current time as the
63 | // request start timestamp.
64 | func markRequestStart(ctx context.Context) context.Context {
65 | 	return context.WithValue(ctx, reqStartTsCtxKey, time.Now())
66 | }
67 |
68 | // getRequestStart returns the timestamp stored by markRequestStart and whether
69 | // one was present.
70 | func getRequestStart(ctx context.Context) (time.Time, bool) {
71 | 	t, ok := ctx.Value(reqStartTsCtxKey).(time.Time)
72 | 	return t, ok
73 | }
74 |
--------------------------------------------------------------------------------
/internal/gatewayserver/hooks/requestid.go:
--------------------------------------------------------------------------------
1 | package hooks
2 |
3 | import (
4 | 	"context"
5 | 	"net/http"
6 |
7 | 	"github.com/rs/xid"
8 | 	"github.com/twitchtv/twirp"
9 |
10 | 	"github.com/razorpay/trino-gateway/internal/constants/contextkeys"
11 | )
12 |
13 | const (
14 | 	requestIDHttpHeaderKey = "X-Request-ID"
15 | )
16 |
17 | // RequestID returns function which puts unique request id into context.
18 | func RequestID() *twirp.ServerHooks { 19 | hooks := &twirp.ServerHooks{} 20 | 21 | hooks.RequestRouted = func(ctx context.Context) (context.Context, error) { 22 | // Reuse the ID stored by WithRequestID (X-Request-ID header); generate a new xid when it is missing or empty. 23 | requestID, _ := ctx.Value(contextkeys.RequestID).(string) 24 | if requestID == "" { 25 | requestID = xid.New().String() 26 | } 27 | ctx = context.WithValue(ctx, contextkeys.RequestID, requestID) 28 | return ctx, nil 29 | } 30 | 31 | return hooks 32 | } 33 | 34 | // WithRequestID is a http handler which puts specific http request header into 35 | // request context which is made available in twirp hooks. 36 | // Refer: https://twitchtv.github.io/twirp/docs/headers.html 37 | func WithRequestID(h http.Handler) http.Handler { 38 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 39 | ctx := setCtxRequestId(r.Context(), r) 40 | r = r.WithContext(ctx) 41 | h.ServeHTTP(w, r) 42 | }) 43 | } 44 | // setCtxRequestId stores the X-Request-ID header value (empty string if the header is absent) in ctx under contextkeys.RequestID. 45 | func setCtxRequestId(ctx context.Context, r *http.Request) context.Context { 46 | requestID := r.Header.Get(requestIDHttpHeaderKey) 47 | return context.WithValue(ctx, contextkeys.RequestID, requestID) 48 | } 49 | -------------------------------------------------------------------------------- /internal/gatewayserver/metrics/metrics.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | "github.com/prometheus/client_golang/prometheus/promauto" 6 | "github.com/razorpay/trino-gateway/internal/boot" 7 | ) 8 | 9 | var ( 10 | RequestsReceivedTotal *prometheus.CounterVec 11 | ResponsesSentTotal *prometheus.CounterVec 12 | ResponseDurations *prometheus.HistogramVec 13 | FallbackGroupInvoked *prometheus.CounterVec 14 | ) 15 | // init registers the gateway's Prometheus collectors via promauto and curries each with the configured app env label. 16 | func init() { 17 | env := boot.Config.App.Env 18 | RequestsReceivedTotal = promauto.NewCounterVec( 19 | prometheus.CounterOpts{ 20 | Name: "trino_gateway_http_requests_total", 21 | Help: "Number of HTTP requests received.", 22 |
}, 23 | []string{"env", "package", "server", "method"}, 24 | ).MustCurryWith(prometheus.Labels{"env": env}) 25 | 26 | ResponsesSentTotal = promauto.NewCounterVec( 27 | prometheus.CounterOpts{ 28 | Name: "trino_gateway_http_responses_total", 29 | Help: "Number of HTTP responses sent.", 30 | }, 31 | []string{"env", "package", "server", "method", "code"}, 32 | ).MustCurryWith(prometheus.Labels{"env": env}) 33 | // HistogramVec.MustCurryWith returns an ObserverVec, hence the assertion back to *prometheus.HistogramVec below. 34 | ResponseDurations = promauto.NewHistogramVec( 35 | prometheus.HistogramOpts{ 36 | Name: "trino_gateway_http_durations_ms_histogram", 37 | Help: "HTTP latency distributions histogram.", 38 | Buckets: []float64{2, 5, 10, 15, 25, 40, 60, 85, 120, 150, 200, 300}, 39 | }, 40 | []string{"env", "package", "server", "method", "code"}, 41 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec) 42 | 43 | FallbackGroupInvoked = promauto.NewCounterVec( 44 | prometheus.CounterOpts{ 45 | Name: "trino_gateway_fallback_group_invoked_total", 46 | Help: "Number of requests where fallback group routing was invoked", 47 | }, 48 | []string{"env"}, 49 | ).MustCurryWith(prometheus.Labels{"env": env}) 50 | } 51 | -------------------------------------------------------------------------------- /internal/gatewayserver/models/backend.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | validation "github.com/go-ozzo/ozzo-validation" 5 | "github.com/razorpay/trino-gateway/pkg/spine" 6 | ) 7 | 8 | // Backend is the database model for a gateway backend (table "backends"). 9 | type Backend struct { 10 | spine.Model 11 | Hostname string `json:"hostname"` 12 | Scheme string `json:"scheme"` 13 | ExternalUrl *string `json:"external_url"` 14 | IsEnabled *bool `json:"is_enabled"` 15 | IsHealthy *bool `json:"is_healthy"` 16 | UptimeSchedule *string `json:"uptime_schedule" gorm:"default:'* * * * *';"` 17 | ClusterLoad *int32 `json:"cluster_load"` 18 | ThresholdClusterLoad *int32 `json:"threshold_cluster_load"` 19 | StatsUpdatedAt *int64
`json:"stats_updated_at"` 20 | } 21 | 22 | func (u *Backend) TableName() string { 23 | return "backends" 24 | } 25 | 26 | func (u *Backend) EntityName() string { 27 | return "backend" 28 | } 29 | 30 | func (u *Backend) SetDefaults() error { 31 | return nil 32 | } 33 | // Validate checks ID, Hostname, Scheme and ExternalUrl with ozzo-validation; the other fields are not validated. 34 | func (u *Backend) Validate() error { 35 | // StatsUpdatedAt and IsEnabled rules are currently disabled (commented out below). 36 | err := validation.ValidateStruct(u, 37 | // id, required, length 1-50 38 | validation.Field(&u.ID, validation.Required, validation.RuneLength(1, 50)), 39 | 40 | // hostname, required, string, length 1-255 41 | validation.Field(&u.Hostname, validation.Required, validation.RuneLength(1, 255)), 42 | 43 | // Scheme, required, string, Union(http, https) 44 | validation.Field(&u.Scheme, validation.Required, validation.In("http", "https")), 45 | 46 | // external_url, required, string, length 1-255 47 | validation.Field(&u.ExternalUrl, validation.Required, validation.RuneLength(1, 255)), 48 | 49 | // validation.Field(&u.StatsUpdatedAt, validation.By(datatype.IsTimestamp)), 50 | 51 | // is_enabled, bool (rule disabled) 52 | // validation.Field(&u.IsEnabled, validation.Required, validation.In(true, false)), 53 | ) 54 | 55 | return err 56 | } 57 | -------------------------------------------------------------------------------- /internal/gatewayserver/models/group.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import "github.com/razorpay/trino-gateway/pkg/spine" 4 | 5 | // Group is the database model (table "groups_") for a group of backends, linked through GroupBackendsMappings. 6 | type Group struct { 7 | spine.Model 8 | Strategy *string `json:"strategy"` 9 | IsEnabled *bool `json:"is_enabled" sql:"DEFAULT:true"` 10 | LastRoutedBackend *string `json:"last_routed_backend"` 11 | GroupBackendsMappings []GroupBackendsMapping `gorm:"foreignKey:GroupId;references:ID"` 12 | } 13 | // TableName returns "groups_"; NOTE(review): presumably suffixed because GROUPS is reserved in MySQL 8 - confirm. 14 | func (u *Group) TableName() string { 15 | return "groups_" 16 | } 17 | 18 | func (u *Group) EntityName() string { 19 | return "group" 20 | } 21 | 22 | func (u *Group)
SetDefaults() error { 23 | return nil 24 | } 25 | 26 | func (u *Group) Validate() error { 27 | return nil 28 | } 29 | // GroupBackendsMapping links a Group to a Backend; GroupId and BackendId form its composite primary key. 30 | type GroupBackendsMapping struct { 31 | spine.Model // NOTE(review): the ID field below shadows spine.Model's promoted ID (the shallower field wins) - confirm intended. 32 | ID *int32 `json:"id" sql:"DEFAULT:NULL"` 33 | GroupId string `json:"group_id" gorm:"primaryKey"` 34 | BackendId string `json:"backend_id" gorm:"primaryKey"` 35 | } 36 | 37 | func (u *GroupBackendsMapping) TableName() string { 38 | return "group_backends_mappings" 39 | } 40 | 41 | func (u *GroupBackendsMapping) EntityName() string { 42 | return "group_backends_mappings" 43 | } 44 | 45 | func (u *GroupBackendsMapping) SetDefaults() error { 46 | return nil 47 | } 48 | 49 | func (u *GroupBackendsMapping) Validate() error { 50 | return nil 51 | } 52 | -------------------------------------------------------------------------------- /internal/gatewayserver/models/policy.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import "github.com/razorpay/trino-gateway/pkg/spine" 4 | 5 | // Policy is the database model (table "policies") mapping a rule (RuleType, RuleValue) to GroupId, with an optional FallbackGroupId. 6 | type Policy struct { 7 | spine.Model 8 | RuleType string `json:"rule_type"` 9 | RuleValue string `json:"rule_value"` 10 | GroupId string `json:"group_id"` 11 | FallbackGroupId *string `json:"fallback_group_id"` 12 | IsEnabled *bool `json:"is_enabled" sql:"DEFAULT:true"` 13 | IsAuthDelegated *bool `json:"is_auth_delegated" sql:"DEFAULT:false"` 14 | SetRequestSource *string `json:"set_request_source"` 15 | } 16 | 17 | func (u *Policy) TableName() string { 18 | return "policies" 19 | } 20 | 21 | func (u *Policy) EntityName() string { 22 | return "policy" 23 | } 24 | 25 | func (u *Policy) SetDefaults() error { 26 | return nil 27 | } 28 | 29 | func (u *Policy) Validate() error { 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /internal/gatewayserver/models/query.go: -------------------------------------------------------------------------------- 1 | package models 2 |
3 | import "github.com/razorpay/trino-gateway/pkg/spine" 4 | 5 | // Query is the database model (table "queries") recording a submitted query's text, submitter and the group/backend it was assigned to. 6 | type Query struct { 7 | spine.Model 8 | Text string `json:"text"` 9 | ClientIp string `json:"client_ip"` 10 | GroupId string `json:"group_id"` 11 | BackendId string `json:"backend_id"` 12 | Username string `json:"username"` 13 | SubmittedAt int64 `json:"submitted_at"` 14 | ServerHost string `json:"server_host"` 15 | } 16 | 17 | func (u *Query) TableName() string { 18 | return "queries" 19 | } 20 | 21 | func (u *Query) EntityName() string { 22 | return "query" 23 | } 24 | 25 | func (u *Query) SetDefaults() error { 26 | return nil 27 | } 28 | // Validate performs no checks for Query; it always returns nil. 29 | func (u *Query) Validate() error { 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /internal/gatewayserver/policyApi/core_test.go: -------------------------------------------------------------------------------- 1 | package policyapi 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | // Test_setIntersection covers empty, disjoint, overlapping, reversed and nil operands; a nil set behaves as an identity operand. 9 | func Test_setIntersection(t *testing.T) { 10 | s_empty := map[string]struct{}{} 11 | s_nil := map[string]struct{}(nil) 12 | s1 := map[string]struct{}{ 13 | "a": {}, 14 | "b": {}, 15 | } 16 | s2 := map[string]struct{}{ 17 | "c": {}, 18 | "d": {}, 19 | } 20 | s3 := map[string]struct{}{ 21 | "a": {}, 22 | "d": {}, 23 | } 24 | s13 := map[string]struct{}{ 25 | "a": {}, 26 | } 27 | s23 := map[string]struct{}{ 28 | "d": {}, 29 | } 30 | 31 | // both empty sets 32 | assert.Equal(t, s_empty, setIntersection(s_empty, s_empty)) 33 | 34 | // first set empty 35 | assert.Equal(t, s_empty, setIntersection(s_empty, s1)) 36 | 37 | // second set empty 38 | assert.Equal(t, s_empty, setIntersection(s1, s_empty)) 39 | 40 | // both non-empty sets with empty intersection 41 | assert.Equal(t, s_empty, setIntersection(s1, s2)) 42 | 43 | // both non-empty sets with non-empty intersection 44 | assert.Equal(t, s13, setIntersection(s1, s3)) 45 | assert.Equal(t, s23,
setIntersection(s3, s2)) 46 | // reverse order 47 | assert.Equal(t, s13, setIntersection(s3, s1)) 48 | assert.Equal(t, s23, setIntersection(s2, s3)) 49 | 50 | // one set nil: the result equals the other set 51 | assert.Equal(t, s3, setIntersection(s_nil, s3)) 52 | 53 | // both sets nil: the result stays nil 54 | assert.Equal(t, s_nil, setIntersection(s_nil, s_nil)) 55 | 56 | // nested intersections with nil operands 57 | assert.Equal(t, s13, setIntersection(setIntersection(setIntersection(s1, s_nil), s_nil), s3)) 58 | } 59 | -------------------------------------------------------------------------------- /internal/gatewayserver/policyApi/validation.go: -------------------------------------------------------------------------------- 1 | package policyapi 2 | 3 | // import ( 4 | // validation "github.com/go-ozzo/ozzo-validation/v4" 5 | // ) 6 | 7 | // func (cp *CreateParams) Validate() error { 8 | // err := validation.ValidateStruct(cp, 9 | // // id, required, length non zero 10 | // validation.Field(&cp.ID, validation.Required, validation.RuneLength(1, 50)), 11 | 12 | // // Hostname, required, string, length 1-50 13 | // validation.Field(&cp.Hostname, validation.Required, validation.RuneLength(1, 50)), 14 | 15 | // // Scheme, required, string, Union(http, https) 16 | // validation.Field(&cp.Scheme, validation.Required, validation.In("http", "https")), 17 | 18 | // // // last_name, required, string, length 1-30 19 | // // validation.Field(&cp.LastName, validation.Required, validation.RuneLength(1, 30)), 20 | // ) 21 | 22 | // return err 23 | // // if err == nil { 24 | // // return nil 25 | // // } 26 | 27 | // // publicErr := errorclass.ErrValidationFailure.New(""). 28 | // // Wrap(err).
29 | // // WithPublic(&errors.Public{ 30 | // // Description: err.Error(), 31 | // // }) 32 | 33 | // // return publicErr 34 | // } 35 | -------------------------------------------------------------------------------- /internal/gatewayserver/queryApi/core.go: -------------------------------------------------------------------------------- 1 | package queryapi 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/fatih/structs" 8 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models" 9 | "github.com/razorpay/trino-gateway/internal/gatewayserver/repo" 10 | fetcherPkg "github.com/razorpay/trino-gateway/pkg/fetcher" 11 | ) 12 | 13 | // entityName is the key under which Query entities are registered with, and returned by, the fetcher. 14 | var entityName string = (&models.Query{}).EntityName() 15 | 16 | // Core implements ICore using a query repository for single-record access 17 | // and a fetcher client for filtered, paginated listings. 18 | type Core struct { 19 | queryRepo repo.IQueryRepo 20 | fetcher fetcherPkg.IClient 21 | } 22 | 23 | // ICore is the query business logic consumed by Server. 24 | type ICore interface { 25 | CreateOrUpdateQuery(ctx context.Context, params *QueryCreateParams) error 26 | GetQuery(ctx context.Context, id string) (*models.Query, error) 27 | FindMany(ctx context.Context, params IFindManyParams) ([]models.Query, error) 28 | } 29 | 30 | // NewCore returns a Core backed by query and fetcher, registering the Query 31 | // entity with fetcher if it has not been registered yet. 32 | func NewCore(query repo.IQueryRepo, fetcher fetcherPkg.IClient) *Core { 33 | if !fetcher.IsEntityRegistered(entityName) { 34 | fetcher.Register(entityName, &models.Query{}, &[]models.Query{}) 35 | } 36 | return &Core{ 37 | queryRepo: query, 38 | fetcher: fetcher, 39 | } 40 | } 41 | 42 | // QueryCreateParams has the attributes required by CreateOrUpdateQuery. 43 | type QueryCreateParams struct { 44 | ID string 45 | Text string 46 | ClientIp string 47 | BackendId string 48 | Username string 49 | GroupId string 50 | ServerHost string 51 | SubmittedAt int64 52 | } 53 |
54 | // CreateOrUpdateQuery upserts the query record identified by params.ID: it 55 | // updates the record when a lookup by ID succeeds and creates it otherwise. 56 | // NOTE(review): any Find error, not only "not found", falls through to Create. 57 | func (c *Core) CreateOrUpdateQuery(ctx context.Context, params *QueryCreateParams) error { 58 | query := models.Query{ 59 | Text: params.Text, 60 | ClientIp: params.ClientIp, 61 | BackendId: params.BackendId, 62 | Username: params.Username, 63 | GroupId: params.GroupId, 64 | ServerHost: params.ServerHost, 65 | SubmittedAt: params.SubmittedAt, 66 | } 67 | query.ID = params.ID 68 | 69 | if _, err := c.queryRepo.Find(ctx, params.ID); err == nil { 70 | return c.queryRepo.Update(ctx, &query) 71 | } 72 | return c.queryRepo.Create(ctx, &query) 73 | } 74 | 75 | // GetQuery returns the query record with the given id. 76 | func (c *Core) GetQuery(ctx context.Context, id string) (*models.Query, error) { 77 | return c.queryRepo.Find(ctx, id) 78 | } 79 | 80 | // IFindManyParams is the listing request accepted by FindMany. 81 | type IFindManyParams interface { 82 | GetCount() int32 83 | GetSkip() int32 84 | GetFrom() int64 85 | GetTo() int64 86 | // GetOrderBy() string 87 | 88 | // custom 89 | GetUsername() string 90 | GetBackendId() string 91 | GetGroupId() string 92 | } 93 | 94 | // Filters holds FindMany's equality filters; the json tag names become the 95 | // filter keys and omitempty drops filters left empty. 96 | type Filters struct { 97 | // custom 98 | Username string `json:"username,omitempty"` 99 | BackendId string `json:"backend_id,omitempty"` 100 | GroupId string `json:"group_id,omitempty"` 101 | } 102 |
103 | // FindMany returns the queries matching the non-empty username, backend and 104 | // group filters in params, paginated by skip/count. The From/To time range is 105 | // applied only when both bounds are non-zero. 106 | func (c *Core) FindMany(ctx context.Context, params IFindManyParams) ([]models.Query, error) { 107 | filters := structs.New(Filters{ 108 | Username: params.GetUsername(), 109 | BackendId: params.GetBackendId(), 110 | GroupId: params.GetGroupId(), 111 | }) 112 | // use the json tag name, so we can respect omitempty tags 113 | filters.TagName = "json" 114 | conditions := filters.Map() 115 | 116 | pagination := fetcherPkg.Pagination{ 117 | Skip: int(params.GetSkip()), 118 | Limit: int(params.GetCount()), 119 | } 120 | 121 | timeRange := fetcherPkg.TimeRange{} 122 | if from, to := params.GetFrom(), params.GetTo(); from != 0 && to != 0 { 123 | timeRange.From = from 124 | timeRange.To = to 125 | } 126 | 127 | resp, err := c.fetcher.FetchMultiple(ctx, fetcherPkg.FetchMultipleRequest{ 128 | EntityName: entityName, 129 | Filter: conditions, 130 | Pagination: pagination, 131 | TimeRange: timeRange, 132 | IsTrashed: false, 133 | HasCreatedAt: true, 134 | }) 135 | if err != nil { 136 | return nil, err 137 | } 138 | 139 | // Use checked assertions: an unexpected response shape should surface as an 140 | // error instead of panicking the RPC handler. 141 | entities, ok := resp.GetEntities().(map[string]interface{}) 142 | if !ok { 143 | return nil, fmt.Errorf("fetching %s entities: unexpected entities type %T", entityName, resp.GetEntities()) 144 | } 145 | queries, ok := entities[entityName].(*[]models.Query) 146 | if !ok || queries == nil { 147 | return nil, fmt.Errorf("fetching %s entities: unexpected result type %T", entityName, entities[entityName]) 148 | } 149 | 150 | return *queries, nil 151 | } 152 |
-------------------------------------------------------------------------------- /internal/gatewayserver/queryApi/server.go: -------------------------------------------------------------------------------- 1 | package queryapi 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models" 7 | "github.com/razorpay/trino-gateway/internal/provider" 8 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway" 9 | _ "github.com/twitchtv/twirp" 10 | ) 11 | 12 | // Server implements the query RPC methods, delegating business logic to an ICore. 13 | type Server struct { 14 | core ICore 15 | } 16 | 17 | // NewServer returns a server. 18 | func NewServer(core ICore) *Server { 19 | return &Server{ 20 | core: core, 21 | } 22 | } 23 | 24 | func (s *Server) CreateOrUpdateQuery(ctx context.Context, req *gatewayv1.Query) (*gatewayv1.Empty, error) { 25 | provider.Logger(ctx).Debugw("CreateOrUpdateQuery", map[string]interface{}{ 26 | "request": req.String(), 27 | }) 28 | 29 | createParams := QueryCreateParams{ 30 | ID: req.GetId(), 31 | Text: req.GetText(), 32 | ClientIp: req.GetClientIp(), 33 | GroupId: req.GetGroupId(), 34 | BackendId: req.GetBackendId(), 35 | Username: req.GetUsername(), 36 | ServerHost: req.GetServerHost(), 37 | SubmittedAt: req.GetSubmittedAt(), 38 | } 39 | 40 | err := s.core.CreateOrUpdateQuery(ctx, &createParams) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | return &gatewayv1.Empty{}, nil 46 | } 47 | 48 | func (s *Server) GetQuery(ctx context.Context, req *gatewayv1.QueryGetRequest) (*gatewayv1.QueryGetResponse, error) { 49 | provider.Logger(ctx).Debugw("GetQuery", map[string]interface{}{ 50 | "request": req.String(), 51 | }) 52 | query, err := s.core.GetQuery(ctx, req.GetId()) 53 | if err != nil { 54 | return nil, err 55 | } 56 | queryProto, err := toQueryResponseProto(query) 57 | if err != nil { 58 | return nil, err 59 | } 60 | return &gatewayv1.QueryGetResponse{Query: queryProto}, nil 61 | } 62 |
63 | func (s *Server) ListQueries(ctx context.Context, req *gatewayv1.QueriesListRequest) (*gatewayv1.QueriesListResponse, error) { 64 | provider.Logger(ctx).Debugw("ListQueries", map[string]interface{}{ 65 | "request": req.String(), 66 | }) 67 | // TODO: ValidateMultiFetchRequest currently accepts every request; add real validation. 68 | 69 | if err := ValidateMultiFetchRequest(ctx, req); err != nil { 70 | return nil, err 71 | } 72 | 73 | queries, err := s.core.FindMany(ctx, req) 74 | if err != nil { 75 | return nil, err 76 | } 77 | 78 | queriesProto := make([]*gatewayv1.Query, len(queries)) 79 | for i, queryModel := range queries { 80 | query, err := toQueryResponseProto(&queryModel) 81 | if err != nil { 82 | return nil, err 83 | } 84 | queriesProto[i] = query 85 | } 86 | 87 | response := gatewayv1.QueriesListResponse{ 88 | Items: queriesProto, 89 | Count: int32(len(queriesProto)), 90 | } 91 | 92 | return &response, nil 93 | } 94 | 95 | func toQueryResponseProto(query *models.Query) (*gatewayv1.Query, error) { 96 | if query == nil { 97 | return &gatewayv1.Query{}, nil 98 | } 99 | return &gatewayv1.Query{ 100 | Id: query.ID, 101 | Text: query.Text, 102 | ServerHost: query.ServerHost, 103 | ClientIp: query.ClientIp, 104 | GroupId: query.GroupId, 105 | BackendId: query.BackendId, 106 | Username: query.Username, 107 | SubmittedAt: query.SubmittedAt, 108 | }, nil 109 | } 110 | 111 | func (s *Server) FindBackendForQuery(ctx context.Context, req *gatewayv1.FindBackendForQueryRequest) (*gatewayv1.FindBackendForQueryResponse, error) { 112 | provider.Logger(ctx).Debugw("FindBackendForQuery", map[string]interface{}{ 113 | "request": req.String(), 114 | }) 115 | 116 | query, err := s.core.GetQuery(ctx, req.QueryId) 117 | if err != nil { 118 | return nil, err 119 | } 120 | return &gatewayv1.FindBackendForQueryResponse{ 121 | BackendId: query.BackendId, 122 | GroupId: query.GroupId, 123 | }, nil 124 | }
125 | -------------------------------------------------------------------------------- /internal/gatewayserver/queryApi/validation.go: -------------------------------------------------------------------------------- 1 | package queryapi 2 | 3 | import ( 4 | "context" 5 | 6 | // validation "github.com/go-ozzo/ozzo-validation/v4" 7 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway" 8 | ) 9 | 10 | // func (cp *CreateParams) Validate() error { 11 | // err := validation.ValidateStruct(cp, 12 | // // id, required, length non zero 13 | // validation.Field(&cp.ID, validation.Required, validation.RuneLength(1, 50)), 14 | 15 | // // Hostname, required, string, length 1-50 16 | // validation.Field(&cp.Hostname, validation.Required, validation.RuneLength(1, 50)), 17 | 18 | // // Scheme, required, string, Union(http, https) 19 | // validation.Field(&cp.Scheme, validation.Required, validation.In("http", "https")), 20 | 21 | // // // last_name, required, string, length 1-30 22 | // // validation.Field(&cp.LastName, validation.Required, validation.RuneLength(1, 30)), 23 | // ) 24 | 25 | // return err 26 | // // if err == nil { 27 | // // return nil 28 | // // } 29 | 30 | // // publicErr := errorclass.ErrValidationFailure.New(""). 31 | // // Wrap(err). 
32 | // // WithPublic(&errors.Public{ 33 | // // Description: err.Error(), 34 | // // }) 35 | 36 | // // return publicErr 37 | // } 38 | 39 | func ValidateMultiFetchRequest(ctx context.Context, req *gatewayv1.QueriesListRequest) error { 40 | // err := validation.ValidateStruct( 41 | // req, 42 | // validation.Field(&req.EntityName, 43 | // validation.Required)) 44 | // if err != nil { 45 | // return rzperror.New(ctx, errorCodes.VALIDATION_ERROR, err).Report() 46 | // } 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /internal/gatewayserver/repo/backend.go: -------------------------------------------------------------------------------- 1 | package repo 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/razorpay/trino-gateway/internal/gatewayserver/database/dbRepo" 7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models" 8 | "github.com/razorpay/trino-gateway/internal/provider" 9 | "github.com/razorpay/trino-gateway/pkg/spine" 10 | ) 11 | 12 | type IBackendRepo interface { 13 | Create(ctx context.Context, backend *models.Backend) error 14 | Update(ctx context.Context, backend *models.Backend) error 15 | Find(ctx context.Context, id string) (*models.Backend, error) 16 | FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Backend, error) 17 | // GetAll(ctx context.Context) ([]models.Backend, error) 18 | // GetAllActive(ctx context.Context) ([]models.Backend, error) 19 | GetAllActiveByIDs(ctx context.Context, ids []string) ([]models.Backend, error) 20 | Delete(ctx context.Context, id string) error 21 | Enable(ctx context.Context, id string) error 22 | Disable(ctx context.Context, id string) error 23 | MarkHealthy(ctx context.Context, id string) error 24 | MarkUnhealthy(ctx context.Context, id string) error 25 | } 26 | 27 | type BackendRepo struct { 28 | repo dbRepo.IDbRepo 29 | } 30 | 31 | // NewBackendRepo returns a new instance of *BackendRepo 32 | func 
NewBackendRepo(repo dbRepo.IDbRepo) *BackendRepo { 33 | return &BackendRepo{repo: repo} 34 | } 35 | 36 | func (r *BackendRepo) Create(ctx context.Context, backend *models.Backend) error { 37 | err := r.repo.Create(ctx, backend) 38 | if err != nil { 39 | provider.Logger(ctx).WithError(err).Error("backend create failed") 40 | return err 41 | } 42 | 43 | provider.Logger(ctx).Infow("backend created", map[string]interface{}{"backend_id": backend.ID}) 44 | 45 | return nil 46 | } 47 | 48 | func (r *BackendRepo) Update(ctx context.Context, backend *models.Backend) error { 49 | err := r.repo.Update(ctx, backend) 50 | if err != nil { 51 | if err == spine.NoRowAffected { 52 | provider.Logger(ctx).Debugw( 53 | "no row affected by backend update", 54 | map[string]interface{}{"backend_id": backend.ID}, 55 | ) 56 | return nil 57 | } 58 | provider.Logger(ctx).WithError(err).Errorw( 59 | "backend update failed", 60 | map[string]interface{}{"backend_id": backend.ID}, 61 | ) 62 | return err 63 | } 64 | 65 | provider.Logger(ctx).Infow("backend updated", map[string]interface{}{"backend_id": backend.ID}) 66 | 67 | return nil 68 | } 69 | 70 | func (r *BackendRepo) Find(ctx context.Context, id string) (*models.Backend, error) { 71 | backend := models.Backend{} 72 | 73 | err := r.repo.FindByID(ctx, &backend, id) 74 | if err != nil { 75 | return nil, err 76 | } 77 | 78 | return &backend, nil 79 | } 80 | 81 | // Returns list of backend ids which are ready to serve traffic 82 | func (r *BackendRepo) GetAllActiveByIDs(ctx context.Context, ids []string) ([]models.Backend, error) { 83 | var backends []models.Backend 84 | 85 | err := r.repo.FindWithConditionByIDs( 86 | ctx, 87 | &backends, 88 | map[string]interface{}{"is_enabled": true, "is_healthy": true}, 89 | ids, 90 | ) 91 | if err != nil { 92 | return nil, err 93 | } 94 | 95 | return backends, nil 96 | } 97 | 98 | func (r *BackendRepo) FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Backend, error) { 99 | var 
backends []models.Backend 100 | 101 | err := r.repo.FindMany(ctx, &backends, conditions) 102 | if err != nil { 103 | return nil, err 104 | } 105 | 106 | return backends, nil 107 | } 108 | 109 | // func (r *BackendRepo) GetAll(ctx context.Context) ([]models.Backend, error) { 110 | // var backends []models.Backend 111 | 112 | // err := r.repo.FindMany(ctx, &backends, make(map[string]interface{})) 113 | // if err != nil { 114 | // return nil, err 115 | // } 116 | 117 | // return &backends, nil 118 | // } 119 | 120 | // func (r *BackendRepo) GetAllActive(ctx context.Context) ([]models.Backend, error) { 121 | // var backends []models.Backend 122 | 123 | // err := r.repo.FindMany(ctx, &backends, map[string]interface{}{"is_enabled": true}) 124 | // if err != nil { 125 | // return nil, err 126 | // } 127 | 128 | // return &backends, nil 129 | // } 130 | 131 | func (r *BackendRepo) Enable(ctx context.Context, id string) error { 132 | provider.Logger(ctx).Infow("backend activation triggered", map[string]interface{}{"backend_id": id}) 133 | 134 | backend, err := r.Find(ctx, id) 135 | if err != nil { 136 | provider.Logger(ctx).Error("backend activation failed: " + err.Error()) 137 | return err 138 | } 139 | 140 | if *backend.IsEnabled { 141 | provider.Logger(ctx).Info("backend activation failed. Already active") 142 | return nil 143 | } 144 | 145 | *backend.IsEnabled = true 146 | 147 | if err := r.repo.Update(ctx, backend); err != nil { 148 | return err 149 | } 150 | 151 | return nil 152 | } 153 | 154 | func (r *BackendRepo) Disable(ctx context.Context, id string) error { 155 | provider.Logger(ctx).Infow("backend deactivation triggered", map[string]interface{}{"backend_id": id}) 156 | 157 | backend, err := r.Find(ctx, id) 158 | if err != nil { 159 | provider.Logger(ctx).Error("backend deactivation failed: " + err.Error()) 160 | return err 161 | } 162 | 163 | if !*backend.IsEnabled { 164 | provider.Logger(ctx).Info("backend deactivation failed. 
Already inactive") 165 | return nil 166 | } 167 | 168 | *backend.IsEnabled = false 169 | 170 | if err := r.repo.Update(ctx, backend); err != nil { 171 | return err 172 | } 173 | 174 | return nil 175 | } 176 | 177 | func (r *BackendRepo) MarkHealthy(ctx context.Context, id string) error { 178 | provider.Logger(ctx).Infow("backend mark as healthy triggered", map[string]interface{}{"backend_id": id}) 179 | 180 | backend, err := r.Find(ctx, id) 181 | if err != nil { 182 | provider.Logger(ctx).Error("backend mark as healthy failed: " + err.Error()) 183 | return err 184 | } 185 | 186 | if *backend.IsHealthy { 187 | provider.Logger(ctx).Info("backend mark as healthy failed. Already healthy") 188 | return nil 189 | } 190 | 191 | *backend.IsHealthy = true 192 | 193 | if err := r.repo.Update(ctx, backend); err != nil { 194 | return err 195 | } 196 | 197 | return nil 198 | } 199 | 200 | func (r *BackendRepo) MarkUnhealthy(ctx context.Context, id string) error { 201 | provider.Logger(ctx).Infow("backend mark as unhealthy triggered", map[string]interface{}{"backend_id": id}) 202 | 203 | backend, err := r.Find(ctx, id) 204 | if err != nil { 205 | provider.Logger(ctx).Error("backend mark as unhealthy failed: " + err.Error()) 206 | return err 207 | } 208 | 209 | if !*backend.IsHealthy { 210 | provider.Logger(ctx).Info("backend mark as unhealthy failed. 
Already unhealthy") 211 | return nil 212 | } 213 | 214 | *backend.IsHealthy = false 215 | 216 | if err := r.repo.Update(ctx, backend); err != nil { 217 | return err 218 | } 219 | 220 | return nil 221 | } 222 | 223 | func (r *BackendRepo) Delete(ctx context.Context, id string) error { 224 | provider.Logger(ctx).Infow("backend delete request", map[string]interface{}{"backend_id": id}) 225 | 226 | backend, err := r.Find(ctx, id) 227 | if err != nil { 228 | provider.Logger(ctx).Error("backend delete failed: " + err.Error()) 229 | return err 230 | } 231 | 232 | // _ = backend 233 | 234 | err = r.repo.Delete(ctx, backend) 235 | if err != nil { 236 | return err 237 | } 238 | 239 | return nil 240 | } 241 | -------------------------------------------------------------------------------- /internal/gatewayserver/repo/group.go: -------------------------------------------------------------------------------- 1 | package repo 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/database/dbRepo" 8 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models" 9 | "github.com/razorpay/trino-gateway/internal/provider" 10 | "github.com/razorpay/trino-gateway/pkg/spine" 11 | "gorm.io/gorm/clause" 12 | ) 13 | 14 | type IGroupRepo interface { 15 | Create(ctx context.Context, group *models.Group) error 16 | Update(ctx context.Context, group *models.Group) error 17 | Find(ctx context.Context, id string) (*models.Group, error) 18 | FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Group, error) 19 | // GetAll(ctx context.Context) ([]models.Group, error) 20 | // GetAllActive(ctx context.Context) ([]models.Group, error) 21 | Delete(ctx context.Context, id string) error 22 | Enable(ctx context.Context, id string) error 23 | Disable(ctx context.Context, id string) error 24 | } 25 | 26 | type GroupRepo struct { 27 | repo dbRepo.IDbRepo 28 | } 29 | 30 | // NewCore returns a new instance of *Core 31 | func 
NewGroupRepo(repo dbRepo.IDbRepo) *GroupRepo { 32 | return &GroupRepo{repo: repo} 33 | } 34 | 35 | func (r *GroupRepo) Create(ctx context.Context, group *models.Group) error { 36 | err := r.repo.Create(ctx, group) 37 | if err != nil { 38 | provider.Logger(ctx).WithError(err).Errorw("group create failed", map[string]interface{}{"group_id": group.ID}) 39 | return err 40 | } 41 | 42 | provider.Logger(ctx).Infow("group created", map[string]interface{}{"group_id": group.ID}) 43 | 44 | return nil 45 | } 46 | 47 | func (r *GroupRepo) Update(ctx context.Context, group *models.Group) error { 48 | if group.GroupBackendsMappings != nil { 49 | err := r.repo.ReplaceAssociations(ctx, group, "GroupBackendsMappings", group.GroupBackendsMappings) 50 | if err != nil { 51 | provider.Logger(ctx).WithError(err).Errorw( 52 | "group update failed, unable to update associations", 53 | map[string]interface{}{"group_id": group.ID}) 54 | return err 55 | } 56 | } 57 | err := r.repo.Update(ctx, group) 58 | if err != nil { 59 | if err == spine.NoRowAffected { 60 | provider.Logger(ctx).Debugw( 61 | "no row affected by group update", 62 | map[string]interface{}{"group_id": group.ID}, 63 | ) 64 | return nil 65 | } 66 | provider.Logger(ctx).WithError(err).Errorw( 67 | "group update failed", 68 | map[string]interface{}{"group_id": group.ID}) 69 | return err 70 | } 71 | 72 | provider.Logger(ctx).Infow("group updated", map[string]interface{}{"group_id": group.ID}) 73 | 74 | return nil 75 | } 76 | 77 | func (r *GroupRepo) Find(ctx context.Context, id string) (*models.Group, error) { 78 | group := models.Group{} 79 | 80 | err := r.repo.Preload(ctx, clause.Associations).FindByID(ctx, &group, id) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | return &group, nil 86 | } 87 | 88 | func (r *GroupRepo) FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Group, error) { 89 | var groups []models.Group 90 | 91 | err := r.repo.Preload(ctx, clause.Associations).FindMany(ctx, 
&groups, conditions) 92 | if err != nil { 93 | return nil, err 94 | } 95 | 96 | return groups, nil 97 | } 98 | 99 | func (r *GroupRepo) Enable(ctx context.Context, id string) error { 100 | provider.Logger(ctx).Infow("group activation triggered", map[string]interface{}{"group_id": id}) 101 | 102 | group, err := r.Find(ctx, id) 103 | if err != nil { 104 | provider.Logger(ctx).Error("group activation failed: " + err.Error()) 105 | return err 106 | } 107 | 108 | if *group.IsEnabled { 109 | provider.Logger(ctx).Error("group activation failed. Already active") 110 | return errors.New("already active") 111 | } 112 | 113 | *group.IsEnabled = true 114 | 115 | if err := r.repo.Update(ctx, group); err != nil { 116 | return err 117 | } 118 | 119 | return nil 120 | } 121 | 122 | func (r *GroupRepo) Disable(ctx context.Context, id string) error { 123 | provider.Logger(ctx).Infow("group activation triggered", map[string]interface{}{"group_id": id}) 124 | 125 | group, err := r.Find(ctx, id) 126 | if err != nil { 127 | provider.Logger(ctx).Error("group activation failed: " + err.Error()) 128 | return err 129 | } 130 | 131 | if !*group.IsEnabled { 132 | provider.Logger(ctx).Error("group deactivation failed. 
Already inactive") 133 | return errors.New("already inactive") 134 | } 135 | 136 | *group.IsEnabled = false 137 | 138 | if err := r.repo.Update(ctx, group); err != nil { 139 | return err 140 | } 141 | 142 | return nil 143 | } 144 | 145 | func (r *GroupRepo) Delete(ctx context.Context, id string) error { 146 | provider.Logger(ctx).Infow("group delete request", map[string]interface{}{"group_id": id}) 147 | 148 | group, err := r.Find(ctx, id) 149 | if err != nil { 150 | provider.Logger(ctx).Error("group delete failed: " + err.Error()) 151 | return err 152 | } 153 | 154 | // _ = group 155 | 156 | err = r.repo.Delete(ctx, group) 157 | if err != nil { 158 | return err 159 | } 160 | 161 | return nil 162 | } 163 | -------------------------------------------------------------------------------- /internal/gatewayserver/repo/policy.go: -------------------------------------------------------------------------------- 1 | package repo 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/database/dbRepo" 8 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models" 9 | "github.com/razorpay/trino-gateway/internal/provider" 10 | "github.com/razorpay/trino-gateway/pkg/spine" 11 | ) 12 | 13 | type IPolicyRepo interface { 14 | Create(ctx context.Context, policy *models.Policy) error 15 | Update(ctx context.Context, policy *models.Policy) error 16 | Find(ctx context.Context, id string) (*models.Policy, error) 17 | FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Policy, error) 18 | // GetAll(ctx context.Context) ([]models.Policy, error) 19 | // GetAllActive(ctx context.Context) ([]models.Policy, error) 20 | Delete(ctx context.Context, id string) error 21 | Enable(ctx context.Context, id string) error 22 | Disable(ctx context.Context, id string) error 23 | } 24 | 25 | type PolicyRepo struct { 26 | repo dbRepo.IDbRepo 27 | } 28 | 29 | // NewCore returns a new instance of *Core 30 | func 
NewPolicyRepo(repo dbRepo.IDbRepo) *PolicyRepo { 31 | return &PolicyRepo{repo: repo} 32 | } 33 | 34 | func (r *PolicyRepo) Create(ctx context.Context, policy *models.Policy) error { 35 | err := r.repo.Create(ctx, policy) 36 | if err != nil { 37 | provider.Logger(ctx).WithError(err).Errorw("policy create failed", map[string]interface{}{"id": policy.ID}) 38 | return err 39 | } 40 | 41 | provider.Logger(ctx).Infow("policy created", map[string]interface{}{"id": policy.ID}) 42 | 43 | return nil 44 | } 45 | 46 | func (r *PolicyRepo) Update(ctx context.Context, policy *models.Policy) error { 47 | err := r.repo.Update(ctx, policy) 48 | if err != nil { 49 | if err == spine.NoRowAffected { 50 | provider.Logger(ctx).Debugw( 51 | "no row affected by policy update", 52 | map[string]interface{}{"policy_id": policy.ID}, 53 | ) 54 | return nil 55 | } 56 | provider.Logger(ctx).WithError(err).Errorw( 57 | "policy update failed", 58 | map[string]interface{}{"policy_id": policy.ID}) 59 | return err 60 | } 61 | 62 | provider.Logger(ctx).Infow("policy updated", map[string]interface{}{"id": policy.ID}) 63 | 64 | return nil 65 | } 66 | 67 | func (r *PolicyRepo) Find(ctx context.Context, id string) (*models.Policy, error) { 68 | policy := models.Policy{} 69 | 70 | err := r.repo.FindByID(ctx, &policy, id) 71 | if err != nil { 72 | return nil, err 73 | } 74 | 75 | return &policy, nil 76 | } 77 | 78 | func (r *PolicyRepo) FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Policy, error) { 79 | var policies []models.Policy 80 | 81 | err := r.repo.FindMany(ctx, &policies, conditions) 82 | if err != nil { 83 | return nil, err 84 | } 85 | 86 | return policies, nil 87 | } 88 | 89 | // func (r *PolicyRepo) GetAll(ctx context.Context) ([]models.Policy, error) { 90 | // var policies []models.Policy 91 | 92 | // err := r.repo.FindMany(ctx, &policies, make(map[string]interface{})) 93 | // if err != nil { 94 | // return nil, err 95 | // } 96 | 97 | // return &policies, nil 98 | 
// } 99 | 100 | // func (r *PolicyRepo) GetAllActive(ctx context.Context) ([]models.Policy, error) { 101 | // var policies []models.Policy 102 | 103 | // err := r.repo.FindMany(ctx, &policies, map[string]interface{}{"is_enabled": true}) 104 | // if err != nil { 105 | // return nil, err 106 | // } 107 | 108 | // return &policies, nil 109 | // } 110 | 111 | func (r *PolicyRepo) Enable(ctx context.Context, id string) error { 112 | provider.Logger(ctx).Infow("policy activation triggered", map[string]interface{}{"policy_id": id}) 113 | 114 | policy, err := r.Find(ctx, id) 115 | if err != nil { 116 | provider.Logger(ctx).Error("policy activation failed: " + err.Error()) 117 | return err 118 | } 119 | 120 | if *policy.IsEnabled { 121 | provider.Logger(ctx).Error("policy activation failed. Already active") 122 | return errors.New("Already active") 123 | } 124 | 125 | *policy.IsEnabled = true 126 | 127 | if err := r.repo.Update(ctx, policy); err != nil { 128 | return err 129 | } 130 | 131 | return nil 132 | } 133 | 134 | func (r *PolicyRepo) Disable(ctx context.Context, id string) error { 135 | provider.Logger(ctx).Infow("policy activation triggered", map[string]interface{}{"policy_id": id}) 136 | 137 | policy, err := r.Find(ctx, id) 138 | if err != nil { 139 | provider.Logger(ctx).Error("policy activation failed: " + err.Error()) 140 | return err 141 | } 142 | 143 | if !*policy.IsEnabled { 144 | provider.Logger(ctx).Error("policy activation failed. 
Already active") 145 | return errors.New("Already active") 146 | } 147 | 148 | *policy.IsEnabled = false 149 | 150 | if err := r.repo.Update(ctx, policy); err != nil { 151 | return err 152 | } 153 | 154 | return nil 155 | } 156 | 157 | func (r *PolicyRepo) Delete(ctx context.Context, id string) error { 158 | provider.Logger(ctx).Infow("policy delete request", map[string]interface{}{"policy_id": id}) 159 | 160 | policy, err := r.Find(ctx, id) 161 | if err != nil { 162 | provider.Logger(ctx).Error("policy delete failed: " + err.Error()) 163 | return err 164 | } 165 | 166 | // _ = policy 167 | 168 | err = r.repo.Delete(ctx, policy) 169 | if err != nil { 170 | return err 171 | } 172 | 173 | return nil 174 | } 175 | -------------------------------------------------------------------------------- /internal/gatewayserver/repo/query.go: -------------------------------------------------------------------------------- 1 | package repo 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/razorpay/trino-gateway/internal/gatewayserver/database/dbRepo" 7 | "github.com/razorpay/trino-gateway/internal/gatewayserver/models" 8 | "github.com/razorpay/trino-gateway/internal/provider" 9 | "github.com/razorpay/trino-gateway/pkg/spine" 10 | ) 11 | 12 | type IQueryRepo interface { 13 | Create(ctx context.Context, query *models.Query) error 14 | Update(ctx context.Context, query *models.Query) error 15 | Find(ctx context.Context, id string) (*models.Query, error) 16 | FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Query, error) 17 | // Find(ctx context.Context, id string) (*Query, error) 18 | // FindAll(ctx context.Context) ([]Query, error) 19 | } 20 | 21 | type QueryRepo struct { 22 | repo dbRepo.IDbRepo 23 | } 24 | 25 | // NewCore returns a new instance of *Core 26 | func NewQueryRepo(repo dbRepo.IDbRepo) *QueryRepo { 27 | return &QueryRepo{repo: repo} 28 | } 29 | 30 | func (r *QueryRepo) Create(ctx context.Context, query *models.Query) error { 31 | err := 
r.repo.Create(ctx, query) 32 | if err != nil { 33 | provider.Logger(ctx).WithError(err).Errorw("query create failed", map[string]interface{}{"query_id": query.ID}) 34 | return err 35 | } 36 | 37 | provider.Logger(ctx).Infow("query created", map[string]interface{}{"query_id": query.ID}) 38 | 39 | return nil 40 | } 41 | 42 | func (r *QueryRepo) Update(ctx context.Context, query *models.Query) error { 43 | err := r.repo.Update(ctx, query) 44 | if err != nil { 45 | if err == spine.NoRowAffected { 46 | provider.Logger(ctx).Debugw( 47 | "no row affected by query update", 48 | map[string]interface{}{"query_id": query.ID}, 49 | ) 50 | return nil 51 | } 52 | provider.Logger(ctx).WithError(err).Errorw( 53 | "query update failed", 54 | map[string]interface{}{"query_id": query.ID}) 55 | return err 56 | } 57 | 58 | provider.Logger(ctx).Infow("query updated", map[string]interface{}{"query_id": query.ID}) 59 | 60 | return nil 61 | } 62 | 63 | func (r *QueryRepo) Find(ctx context.Context, id string) (*models.Query, error) { 64 | query := models.Query{} 65 | 66 | err := r.repo.FindByID(ctx, &query, id) 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | return &query, nil 72 | } 73 | 74 | func (r *QueryRepo) FindMany(ctx context.Context, conditions map[string]interface{}) ([]models.Query, error) { 75 | var queries []models.Query 76 | 77 | err := r.repo.FindMany(ctx, &queries, conditions) 78 | if err != nil { 79 | return nil, err 80 | } 81 | 82 | return queries, nil 83 | } 84 | -------------------------------------------------------------------------------- /internal/monitor/metric.go: -------------------------------------------------------------------------------- 1 | package monitor 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | "github.com/prometheus/client_golang/prometheus/promauto" 6 | "github.com/razorpay/trino-gateway/internal/boot" 7 | ) 8 | 9 | type Metrics struct { 10 | executionsTotal *prometheus.CounterVec 11 | executionlastRunAt 
*prometheus.GaugeVec 12 | executionDurations *prometheus.HistogramVec 13 | backendLoad *prometheus.GaugeVec 14 | } 15 | 16 | var metrics *Metrics 17 | 18 | func initMetrics() { 19 | env := boot.Config.App.Env 20 | metrics = &Metrics{} 21 | metrics.executionsTotal = promauto.NewCounterVec( 22 | prometheus.CounterOpts{ 23 | Name: "trino_gateway_monitor_executions_total", 24 | Help: "Number of executions triggered for monitor task.", 25 | }, 26 | []string{"env"}, 27 | ).MustCurryWith(prometheus.Labels{"env": env}) 28 | 29 | metrics.executionlastRunAt = promauto.NewGaugeVec( 30 | prometheus.GaugeOpts{ 31 | Name: "trino_gateway_monitor_execution_last_run_at", 32 | Help: "Monitor task last run epoch ts.", 33 | }, 34 | []string{"env"}, 35 | ).MustCurryWith(prometheus.Labels{"env": env}) 36 | 37 | metrics.executionDurations = promauto.NewHistogramVec( 38 | prometheus.HistogramOpts{ 39 | Name: "trino_gateway_monitor_execution_seconds_histogram", 40 | Help: "Monitor task execution time distributions histogram.", 41 | Buckets: []float64{5, 15, 30, 60, 90, 120, 150, 180, 210, 240}, 42 | }, 43 | []string{"env"}, 44 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec) 45 | 46 | metrics.backendLoad = promauto.NewGaugeVec( 47 | prometheus.GaugeOpts{ 48 | Name: "trino_gateway_monitor_backend_load", 49 | Help: "Backend Load computed by last run of monitor task.", 50 | }, 51 | []string{"env", "backend"}, 52 | ).MustCurryWith(prometheus.Labels{"env": env}) 53 | } 54 | -------------------------------------------------------------------------------- /internal/monitor/monitor.go: -------------------------------------------------------------------------------- 1 | package monitor 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | 9 | "github.com/go-co-op/gocron" 10 | "github.com/razorpay/trino-gateway/internal/provider" 11 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway" 12 | ) 13 | 14 | type Monitor struct { 15 | core ICore 16 | } 17 | 18 
| func init() { 19 | initMetrics() 20 | } 21 | 22 | func NewMonitor(core ICore) *Monitor { 23 | return &Monitor{ 24 | core: core, 25 | } 26 | } 27 | 28 | func (m *Monitor) Schedule(ctx *context.Context, interval string) error { 29 | s := gocron.NewScheduler(time.UTC) 30 | j, err := s.Every(interval).Do(m.Execute, ctx) 31 | if err != nil { 32 | return err 33 | } 34 | s.SetMaxConcurrentJobs(1, gocron.RescheduleMode) 35 | s.StartImmediately().StartAsync() 36 | 37 | provider.Logger(*ctx).Infow("Scheduled Monitoring Job", map[string]interface{}{ 38 | "job": fmt.Sprintf("%v", j), 39 | "nextRunUTC": j.NextRun().Local().UTC(), 40 | }) 41 | 42 | return nil 43 | } 44 | 45 | func (m *Monitor) Execute(ctx *context.Context) { 46 | provider.Logger(*ctx).Info("Executing monitoring task") 47 | 48 | metrics.executionsTotal. 49 | WithLabelValues().Inc() 50 | 51 | defer func(st time.Time) { 52 | duration := float64(time.Since(st).Seconds()) 53 | metrics.executionDurations. 54 | WithLabelValues().Observe(duration) 55 | 56 | metrics.executionlastRunAt. 
57 | WithLabelValues().SetToCurrentTime() 58 | }(time.Now()) 59 | 60 | provider.Logger(*ctx).Info("Evaluating new state for backends") 61 | newStates, err := m.core.EvaluateBackendNewState(ctx) 62 | if err != nil { 63 | provider.Logger(*ctx).WithError(err).Error("Error evaluating new states for backends") 64 | return 65 | } 66 | 67 | if len(newStates.Healthy) == 0 { 68 | provider.Logger(*ctx).Error("No Backends are in Healthy state.") 69 | } 70 | 71 | provider.Logger(*ctx).Debug("Marking healthy/unhealthy backends as per evaluated states") 72 | var wg sync.WaitGroup 73 | for _, b := range newStates.Unhealthy { 74 | wg.Add(1) 75 | go func(x *gatewayv1.Backend) { 76 | defer wg.Done() 77 | err = m.core.MarkUnhealthyBackend(ctx, x) 78 | if err != nil { 79 | provider.Logger(*ctx).WithError(err).Errorw( 80 | "Failure marking backend as Unhealthy", 81 | map[string]interface{}{"backend": x}) 82 | } 83 | }(b) 84 | } 85 | 86 | for _, b := range newStates.Healthy { 87 | wg.Add(1) 88 | go func(x *gatewayv1.Backend) { 89 | defer wg.Done() 90 | err = m.core.MarkHealthyBackend(ctx, x) 91 | if err != nil { 92 | provider.Logger(*ctx).WithError(err).Errorw( 93 | "Failure marking backend as Healthy", 94 | map[string]interface{}{"backend": x}) 95 | } 96 | }(b) 97 | } 98 | 99 | // Wait for all backend health updates to complete. 
100 | wg.Wait() 101 | 102 | provider.Logger(*ctx).Info("Finished executing monitoring task") 103 | } 104 | -------------------------------------------------------------------------------- /internal/provider/logger.go: -------------------------------------------------------------------------------- 1 | package provider 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/razorpay/trino-gateway/pkg/logger" 7 | ) 8 | 9 | // Logger will provider the logger instance 10 | func Logger(ctx context.Context) *logger.Entry { 11 | ctxLogger, err := logger.Ctx(ctx) 12 | if err == nil { 13 | return ctxLogger 14 | } 15 | 16 | panic(err.Error()) 17 | } 18 | -------------------------------------------------------------------------------- /internal/router/auth.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io/ioutil" 9 | "net/http" 10 | "time" 11 | 12 | "github.com/razorpay/trino-gateway/internal/boot" 13 | "github.com/razorpay/trino-gateway/internal/provider" 14 | "github.com/razorpay/trino-gateway/internal/router/trinoheaders" 15 | "github.com/razorpay/trino-gateway/internal/utils" 16 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway" 17 | ) 18 | 19 | type IAuthService interface { 20 | Authenticate(ctx *context.Context, username string, password string) (bool, error) 21 | } 22 | 23 | type AuthService struct { 24 | ValidationProviderURL string 25 | ValidationProviderToken string 26 | } 27 | 28 | /* 29 | Calls an Auth Token Validator Service with following api contract: 30 | With Params- 31 | 32 | { 33 | "email": "abc@xyz.com", 34 | "token": "token123" 35 | } 36 | 37 | api returns- 38 | If Authenticated - {"ok": true} 39 | If not authenticated- {"ok": false} 40 | 41 | @returns- 42 | boolean{True or False},error_message 43 | */ 44 | func (s *AuthService) ValidateFromValidationProvider(ctx *context.Context, username string, password string) 
(bool, error) { 45 | payload := struct { 46 | Username string `json:"email"` 47 | Token string `json:"token"` 48 | }{ 49 | Username: username, 50 | Token: password, 51 | } 52 | 53 | payloadBytes, _ := json.Marshal(payload) 54 | req, _ := http.NewRequest("POST", s.ValidationProviderURL, bytes.NewReader(payloadBytes)) 55 | req.Header.Set("X-Auth-Token", s.ValidationProviderToken) 56 | req.Header.Set("Content-Type", "application/json") 57 | 58 | client := &http.Client{} 59 | resp, err := client.Do(req) 60 | if err != nil { 61 | return false, err 62 | } 63 | defer resp.Body.Close() 64 | 65 | respBody, _ := ioutil.ReadAll(resp.Body) 66 | 67 | var data struct { 68 | OK bool `json:"ok"` 69 | } 70 | jsonParseError := json.Unmarshal([]byte(respBody), &data) 71 | if jsonParseError != nil { 72 | return false, jsonParseError 73 | } 74 | 75 | return data.OK, nil 76 | } 77 | 78 | func (s *AuthService) Authenticate(ctx *context.Context, username string, password string) (bool, error) { 79 | authCache := s.GetInMemoryAuthCache(ctx) 80 | 81 | if entry, exists := authCache.Get(username); exists && entry == password { 82 | authCache.Update(username, password) 83 | return true, nil 84 | } 85 | 86 | isValid, err := s.ValidateFromValidationProvider(ctx, username, password) 87 | if err != nil { 88 | return false, err 89 | } 90 | 91 | if isValid { 92 | authCache.Update(username, password) 93 | } 94 | 95 | return isValid, nil 96 | } 97 | 98 | func (s *AuthService) GetInMemoryAuthCache(ctx *context.Context) utils.ISimpleCache { 99 | ctxKeyName := "routerAuthCache" 100 | authCache, ok := (*ctx).Value(ctxKeyName).(*utils.InMemorySimpleCache) 101 | 102 | if !ok { 103 | 104 | expiryInterval, _ := time.ParseDuration(boot.Config.Auth.Router.DelegatedAuth.CacheTTLMinutes) 105 | 106 | authCache = &utils.InMemorySimpleCache{ 107 | Cache: make(map[string]struct { 108 | Timestamp time.Time 109 | Value string 110 | }), 111 | ExpiryInterval: expiryInterval, 112 | } 113 | *ctx = context.WithValue(*ctx, 
ctxKeyName, authCache) 114 | } 115 | return authCache 116 | } 117 | 118 | func (r *RouterServer) isAuthDelegated(ctx *context.Context) (bool, error) { 119 | res, err := r.gatewayApiClient.Policy.EvaluateAuthDelegationForClient(*ctx, &gatewayv1.EvaluateAuthDelegationRequest{IncomingPort: int32(r.port)}) 120 | 121 | if err != nil { 122 | provider.Logger(*ctx).WithError(err).Errorw( 123 | fmt.Sprint(LOG_TAG, "Failed to evaluate auth delegation policy. Assuming delegation is disabled."), 124 | map[string]interface{}{ 125 | "port": r.port, 126 | }) 127 | return false, err 128 | } 129 | return res.GetIsAuthDelegated(), nil 130 | } 131 | 132 | func (r *RouterServer) AuthHandler(ctx *context.Context, h http.Handler) http.Handler { 133 | return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 134 | if isAuth, _ := r.isAuthDelegated(ctx); isAuth { 135 | // TODO: Refactor auth type handling to a dedicated type 136 | 137 | // BasicAuth 138 | username, password, isBasicAuth := req.BasicAuth() 139 | 140 | // CustomAuth 141 | if !isBasicAuth { 142 | provider.Logger(*ctx).Debug("Custom Auth type") 143 | username = trinoheaders.Get(trinoheaders.User, req) 144 | password = trinoheaders.Get(trinoheaders.Password, req) 145 | } else { 146 | if u := trinoheaders.Get(trinoheaders.User, req); u != username { 147 | errorMsg := fmt.Sprintf("Username from basicauth - %s does not match with User principal - %s", username, u) 148 | provider.Logger(*ctx).Debug(errorMsg) 149 | http.Error(w, errorMsg, http.StatusUnauthorized) 150 | } 151 | 152 | // Remove auth details from request 153 | req.Header.Del("Authorization") 154 | } 155 | 156 | // NoAuth 157 | isNoAuth := password == "" 158 | if isNoAuth { 159 | provider.Logger(*ctx).Debug("No Auth type detected") 160 | errorMsg := fmt.Sprintf("Password required") 161 | http.Error(w, errorMsg, http.StatusUnauthorized) 162 | return 163 | } 164 | 165 | isAuthenticated, err := r.authService.Authenticate(ctx, username, password) 166 | 167 | 
if err != nil { 168 | errorMsg := fmt.Sprintf("Unable to Authenticate users. Getting error - %s", err) 169 | provider.Logger(*ctx).Error(errorMsg) 170 | http.Error(w, "Unable to Authenticate the user", http.StatusNotFound) 171 | return 172 | } 173 | if !isAuthenticated { 174 | provider.Logger(*ctx).Debug(fmt.Sprintf("User - %s not authenticated", username)) 175 | http.Error(w, "User not authenticated", http.StatusUnauthorized) 176 | return 177 | } 178 | h.ServeHTTP(w, req) 179 | } else { 180 | h.ServeHTTP(w, req) 181 | } 182 | }) 183 | } 184 | -------------------------------------------------------------------------------- /internal/router/auth_test.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/razorpay/trino-gateway/pkg/logger" 8 | "github.com/stretchr/testify/suite" 9 | ) 10 | 11 | // Define the suite, and absorb the built-in basic suite 12 | // functionality from testify - including a T() method which 13 | // returns the current testing context 14 | type AuthSuite struct { 15 | suite.Suite 16 | authService *AuthService 17 | ctx *context.Context 18 | } 19 | 20 | func (suite *AuthSuite) SetupTest() { 21 | lgrConfig := logger.Config{ 22 | LogLevel: logger.Warn, 23 | } 24 | 25 | l, err := logger.NewLogger(lgrConfig) 26 | if err != nil { 27 | panic("failed to initialize logger") 28 | } 29 | 30 | c := context.WithValue(context.Background(), logger.LoggerCtxKey, l) 31 | 32 | suite.ctx = &c 33 | suite.authService = &AuthService{} 34 | } 35 | 36 | func (suite *AuthSuite) Test_GetInMemoryAuthCache_Persistance() { 37 | key := "testKey" 38 | value := "testValue" 39 | 40 | authCache := suite.authService.GetInMemoryAuthCache(suite.ctx) 41 | authCache.Update(key, value) 42 | 43 | authCacheInstance2 := suite.authService.GetInMemoryAuthCache(suite.ctx) 44 | entry, exists := authCacheInstance2.Get(key) 45 | 46 | suite.Truef(exists, "Second cache instance 
doesn't have same key") 47 | if exists { 48 | suite.Equalf(value, entry, "Second Cache instance value doesn't match.") 49 | } 50 | 51 | } 52 | func TestAuthSuite(t *testing.T) { 53 | suite.Run(t, new(AuthSuite)) 54 | } 55 | -------------------------------------------------------------------------------- /internal/router/metric.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | "github.com/prometheus/client_golang/prometheus/promauto" 6 | "github.com/razorpay/trino-gateway/internal/boot" 7 | ) 8 | 9 | type Metrics struct { 10 | requestsReceivedTotal *prometheus.CounterVec 11 | requestsRoutedTotal *prometheus.CounterVec 12 | requestPreRoutingDelays *prometheus.HistogramVec 13 | requestPostRoutingDelays *prometheus.HistogramVec 14 | responsesSentTotal *prometheus.CounterVec 15 | responseDurations *prometheus.HistogramVec 16 | } 17 | 18 | var metrics *Metrics 19 | 20 | func initMetrics() { 21 | env := boot.Config.App.Env 22 | metrics = &Metrics{} 23 | metrics.requestsReceivedTotal = promauto.NewCounterVec( 24 | prometheus.CounterOpts{ 25 | Name: "trino_gateway_router_http_requests_total", 26 | Help: "Number of HTTP requests received from clients.", 27 | }, 28 | []string{"env", "method", "port"}, 29 | ).MustCurryWith(prometheus.Labels{"env": env}) 30 | 31 | metrics.requestsRoutedTotal = promauto.NewCounterVec( 32 | prometheus.CounterOpts{ 33 | Name: "trino_gateway_router_http_requests_routed_total", 34 | Help: "Number of HTTP requests routed to a trino server.", 35 | }, 36 | []string{"env", "method", "port", "group", "backend"}, 37 | ).MustCurryWith(prometheus.Labels{"env": env}) 38 | 39 | metrics.requestPreRoutingDelays = promauto.NewHistogramVec( 40 | prometheus.HistogramOpts{ 41 | Name: "trino_gateway_router_http_pre_routing_delay_ms_histogram", 42 | Help: "Delay in routing client request to a Trino server, latency distributions histogram.", 
43 | Buckets: []float64{5, 10, 15, 20, 30, 40, 60, 100, 150, 500}, 44 | }, 45 | []string{"env", "method"}, 46 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec) 47 | 48 | metrics.requestPostRoutingDelays = promauto.NewHistogramVec( 49 | prometheus.HistogramOpts{ 50 | Name: "trino_gateway_router_http_post_routing_delay_ms_histogram", 51 | Help: "Delay in sending response to client after receiving response from Trino server, latency distributions histogram.", 52 | Buckets: []float64{2, 5, 10, 15, 20, 25, 30, 40, 50, 100, 500}, 53 | }, 54 | []string{"env", "method", "code"}, 55 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec) 56 | 57 | metrics.responsesSentTotal = promauto.NewCounterVec( 58 | prometheus.CounterOpts{ 59 | Name: "trino_gateway_router_http_responses_total", 60 | Help: "Number of HTTP responses sent back to client.", 61 | }, 62 | []string{"env", "method", "code"}, 63 | ).MustCurryWith(prometheus.Labels{"env": env}) 64 | 65 | metrics.responseDurations = promauto.NewHistogramVec( 66 | prometheus.HistogramOpts{ 67 | Name: "trino_gateway_router_http_durations_ms_histogram", 68 | Help: "Router HTTP latency distributions histogram for responses sent to clients.", 69 | Buckets: []float64{20, 40, 60, 90, 120, 150, 200, 250, 300, 500}, 70 | }, 71 | []string{"env", "method", "code"}, 72 | ).MustCurryWith(prometheus.Labels{"env": env}).(*prometheus.HistogramVec) 73 | } 74 | -------------------------------------------------------------------------------- /internal/router/request_test.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/razorpay/trino-gateway/pkg/logger" 8 | "github.com/stretchr/testify/suite" 9 | ) 10 | 11 | type HelpersSuite struct { 12 | suite.Suite 13 | ctx *context.Context 14 | } 15 | 16 | func (suite *HelpersSuite) SetupTest() { 17 | lgrConfig := logger.Config{ 18 | LogLevel: 
logger.Warn, 19 | } 20 | 21 | l, err := logger.NewLogger(lgrConfig) 22 | if err != nil { 23 | panic("failed to initialize logger") 24 | } 25 | 26 | c := context.WithValue(context.Background(), logger.LoggerCtxKey, l) 27 | 28 | suite.ctx = &c 29 | } 30 | 31 | func (suite *HelpersSuite) Test_extractQueryId() { 32 | } 33 | 34 | func (suite *HelpersSuite) Test_isValidRequest() { 35 | } 36 | 37 | func (suite *HelpersSuite) Test_constructQueryFromReq() { 38 | } 39 | 40 | func TestSuite(t *testing.T) { 41 | suite.Run(t, new(HelpersSuite)) 42 | } 43 | -------------------------------------------------------------------------------- /internal/router/request_type.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "fmt" 5 | 6 | gatewayv1 "github.com/razorpay/trino-gateway/rpc/gateway" 7 | ) 8 | 9 | type ClientRequest interface { 10 | isClientRequest() 11 | Validate() error 12 | } 13 | 14 | type UiRequest struct { 15 | ClientRequest 16 | queryId string 17 | } 18 | 19 | func (UiRequest) isClientRequest() {} 20 | 21 | func (r UiRequest) Validate() error { 22 | if r.queryId == "" { 23 | tag := "ui" 24 | return fmt.Errorf("%s: %s", tag, "Missing query id") 25 | } 26 | return nil 27 | } 28 | 29 | type QueryApiRequest struct { 30 | ClientRequest 31 | headerConnectionProperties string 32 | headerClientTags string 33 | incomingPort int32 34 | transactionId string 35 | clientHost string 36 | Query *gatewayv1.Query 37 | } 38 | 39 | func (QueryApiRequest) isClientRequest() {} 40 | func (r QueryApiRequest) Validate() error { 41 | tag := "query api" 42 | if r.Query.GetUsername() == "" { 43 | return fmt.Errorf("%s: %s", tag, "Missing Trino Username header") 44 | } 45 | if r.Query.GetId() == "" { 46 | return fmt.Errorf("%s: %s", tag, "Missing Query Id") 47 | } 48 | 49 | // TODO: remove it once transaction support is added 50 | // Looker's Presto client sends `X-Presto-Transaction-Id: NONE` 51 | // whereas trino client 
doesnt send it if its not set 52 | if !(r.transactionId == "" || r.transactionId == "NONE") { 53 | return fmt.Errorf("%s: %s", tag, "Transactions are not supported in gateway.") 54 | } 55 | return nil 56 | } 57 | 58 | type QueryRequest struct { 59 | ClientRequest 60 | headerConnectionProperties string 61 | headerClientTags string 62 | incomingPort int32 63 | transactionId string 64 | clientHost string 65 | Query *gatewayv1.Query 66 | } 67 | 68 | func (QueryRequest) isClientRequest() {} 69 | func (r QueryRequest) Validate() error { 70 | tag := "query submission" 71 | if r.Query.GetUsername() == "" { 72 | return fmt.Errorf("%s: %s", tag, "Missing Trino Username header") 73 | } 74 | if r.Query.GetText() == "" { 75 | return fmt.Errorf("%s: %s", tag, "Missing Query text") 76 | } 77 | 78 | // TODO: remove it once transaction support is added 79 | // Looker's Presto client sends `X-Presto-Transaction-Id: NONE` 80 | // whereas trino client doesnt send it if its not set 81 | if !(r.transactionId == "" || r.transactionId == "NONE") { 82 | return fmt.Errorf("%s: %s", tag, "Transactions are not supported in gateway.") 83 | } 84 | return nil 85 | } 86 | 87 | type ApiRequest struct { 88 | ClientRequest 89 | } 90 | 91 | func (ApiRequest) isClientRequest() {} 92 | func (r ApiRequest) Validate() error { 93 | return nil 94 | } 95 | -------------------------------------------------------------------------------- /internal/router/response.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "regexp" 9 | "time" 10 | 11 | "github.com/razorpay/trino-gateway/internal/provider" 12 | "github.com/razorpay/trino-gateway/internal/utils" 13 | ) 14 | 15 | func (r *RouterServer) handleRedirect(ctx *context.Context, resp *http.Response) error { 16 | // This needs more testing with looker 17 | 18 | // Handle Redirects 19 | // TODO: Clean it up 20 | // TODO - validate 
its working for all use cases 21 | regex := regexp.MustCompile(`\w+\:\/\/[^\/]*(.*)`) 22 | if resp.Header.Get("Location") != ("") { 23 | oldLoc := resp.Header.Get("Location") 24 | newLoc := fmt.Sprint("http://", r.routerHostname, regex.ReplaceAllString(oldLoc, "$1")) 25 | resp.Header.Set("Location", newLoc) 26 | } 27 | return nil 28 | } 29 | 30 | func (r *RouterServer) ProcessResponse( 31 | ctx *context.Context, 32 | resp *http.Response, 33 | cReq ClientRequest, 34 | ) error { 35 | switch stCode := resp.StatusCode; true { 36 | case stCode >= 200 && stCode < 300: 37 | // TODO - fix redirect 38 | _ = r.handleRedirect(ctx, resp) 39 | case stCode >= 300 && stCode < 400: 40 | // http3xx -> server sent redirection, gateway doesn't need to modify anything here 41 | // Assuming Clients can directly connect to redirected Uri 42 | default: 43 | provider.Logger(*ctx).Errorw( 44 | fmt.Sprint(LOG_TAG, "Routing unsuccessful"), 45 | map[string]interface{}{ 46 | "serverResponse": utils.StringifyHttpRequestOrResponse(ctx, resp), 47 | }) 48 | return nil 49 | } 50 | 51 | provider.Logger(*ctx).Debug(LOG_TAG + "Routing successful") 52 | 53 | switch nt := cReq.(type) { 54 | case *ApiRequest: 55 | return nil 56 | case *UiRequest: 57 | return nil 58 | case *QueryRequest: 59 | req := nt.Query 60 | body, err := utils.ParseHttpPayloadBody(ctx, &resp.Body, utils.GetHttpBodyEncoding(ctx, resp)) 61 | if err != nil { 62 | provider.Logger(*ctx).WithError(err).Error(fmt.Sprint(LOG_TAG, "unable to parse body of server response")) 63 | } 64 | 65 | go func() { 66 | req.Id = extractQueryIdFromServerResponse(ctx, body) 67 | req.SubmittedAt = time.Now().Unix() 68 | 69 | _, err = r.gatewayApiClient.Query.CreateOrUpdateQuery(*ctx, req) 70 | if err != nil { 71 | provider.Logger( 72 | *ctx).WithError(err).Errorw( 73 | fmt.Sprint(LOG_TAG, "Unable to save query"), 74 | map[string]interface{}{ 75 | "query_id": req.Id, 76 | }) 77 | } 78 | }() 79 | 80 | provider.Logger(*ctx).Debugw("Server Response Processed", 
map[string]interface{}{ 81 | "resp": utils.StringifyHttpRequestOrResponse(ctx, resp), 82 | }) 83 | 84 | return nil 85 | default: 86 | return nil 87 | } 88 | } 89 | 90 | func extractQueryIdFromServerResponse(ctx *context.Context, body string) string { 91 | provider.Logger(*ctx).Debugw(fmt.Sprint(LOG_TAG, "extracting queryId from server response"), 92 | map[string]interface{}{ 93 | "body": body, 94 | }) 95 | var resp struct{ Id string } 96 | json.Unmarshal([]byte(body), &resp) 97 | return resp.Id 98 | } 99 | -------------------------------------------------------------------------------- /internal/router/trinoheaders/trino.go: -------------------------------------------------------------------------------- 1 | package trinoheaders 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | // https://github.com/trinodb/trino/blob/master/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java 9 | // Connection properties is not part of trino client protocol but is sent from some jdbc clients 10 | const ( 11 | PreparedStatement = "Prepared-Statement" 12 | User = "User" 13 | ClientTags = "Client-Tags" 14 | ConnectionProperties = "Connection-Properties" 15 | TransactionId = "Transaction-Id" 16 | Password = "Password" 17 | Source = "Source" 18 | ) 19 | 20 | var allowedPrefixes = [...]string{"Presto", "Trino"} 21 | 22 | func Get(key string, req *http.Request) string { 23 | for _, h := range allowedPrefixes { 24 | s := fmt.Sprintf("X-%s-%s", h, key) 25 | 26 | if val := req.Header.Get(s); val != "" { 27 | return val 28 | } 29 | } 30 | return "" 31 | } 32 | -------------------------------------------------------------------------------- /internal/router/trinoheaders/trino_test.go: -------------------------------------------------------------------------------- 1 | package trinoheaders 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_Get(t *testing.T) { 11 | trinoHttpReq := &http.Request{ 12 | 
Header: map[string][]string{ 13 | "X-Trino-User": {"user"}, 14 | "X-Trino-Connection-Properties": {"connProps"}, 15 | }, 16 | } 17 | assert.Equal(t, Get("User", trinoHttpReq), "user") 18 | assert.Equal(t, Get("Connection-Properties", trinoHttpReq), "connProps") 19 | 20 | prestoHttpReq := &http.Request{ 21 | Header: map[string][]string{ 22 | "X-Presto-User": {"user"}, 23 | "X-Presto-Connection-Properties": {"connProps"}, 24 | }, 25 | } 26 | assert.Equal(t, Get("User", prestoHttpReq), "user") 27 | assert.Equal(t, Get("Connection-Properties", prestoHttpReq), "connProps") 28 | } 29 | -------------------------------------------------------------------------------- /internal/utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "context" 7 | "io" 8 | "io/ioutil" 9 | "net/http" 10 | "net/http/httputil" 11 | "sync" 12 | "time" 13 | 14 | "github.com/razorpay/trino-gateway/internal/provider" 15 | "github.com/robfig/cron/v3" 16 | ) 17 | 18 | /* 19 | Checks whether provided time object is in 1 minute window of cron expression 20 | */ 21 | func IsTimeInCron(ctx *context.Context, t time.Time, sched string) (bool, error) { 22 | s, err := cron.ParseStandard(sched) 23 | if err != nil { 24 | return false, err 25 | } 26 | nextRun := s.Next(t) 27 | 28 | provider.Logger(*ctx).Debugw( 29 | "Evaluated next valid ts from cron expression", 30 | map[string]interface{}{ 31 | "providedTime": t, 32 | "nextRun": nextRun, 33 | }, 34 | ) 35 | 36 | return nextRun.Sub(t).Minutes() <= 1, nil 37 | } 38 | 39 | func SliceContains[T comparable](collection []T, element T) bool { 40 | for _, item := range collection { 41 | if item == element { 42 | return true 43 | } 44 | } 45 | 46 | return false 47 | } 48 | 49 | // Finds intersection of 2 slices 50 | func SimpleSliceIntersection[T comparable](list1 []T, list2 []T) []T { 51 | result := []T{} 52 | seen := map[T]struct{}{} 53 | 54 | for _, elem 
:= range list1 { 55 | seen[elem] = struct{}{} 56 | } 57 | 58 | for _, elem := range list2 { 59 | if _, ok := seen[elem]; ok { 60 | result = append(result, elem) 61 | } 62 | } 63 | 64 | return result 65 | } 66 | 67 | func GetHttpBodyEncoding[T *http.Request | *http.Response](ctx *context.Context, r T) string { 68 | enc := "" 69 | headerKey := "Content-Encoding" 70 | switch v := any(r).(type) { 71 | case *http.Request: 72 | enc = v.Header.Get(headerKey) 73 | case *http.Response: 74 | enc = v.Header.Get(headerKey) 75 | } 76 | return enc 77 | } 78 | 79 | func StringifyHttpRequestOrResponse[T *http.Request | *http.Response](ctx *context.Context, r T) string { 80 | canDumpBody := GetHttpBodyEncoding(ctx, r) == "" 81 | if !canDumpBody { 82 | provider.Logger(*ctx).Debug( 83 | "Encoded body in http payload, assuming binary data and skipping dump of body") 84 | } 85 | var res []byte 86 | var err error 87 | switch v := any(r).(type) { 88 | case *http.Request: 89 | res, err = httputil.DumpRequest(v, canDumpBody) 90 | case *http.Response: 91 | res, err = httputil.DumpResponse(v, canDumpBody) 92 | } 93 | if err != nil { 94 | provider.Logger(*ctx).Errorw( 95 | "Unable to stringify http payload", 96 | map[string]interface{}{ 97 | "error": err.Error(), 98 | }) 99 | } 100 | return string(res) 101 | } 102 | 103 | func ParseHttpPayloadBody(ctx *context.Context, body *io.ReadCloser, encoding string) (string, error) { 104 | bodyBytes, err := io.ReadAll(*body) 105 | if err != nil { 106 | return "", err 107 | } 108 | // since its a ReadCloser type, the stream will be empty after its read once 109 | // ensure it is restored in original object 110 | *body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) 111 | 112 | switch encoding { 113 | case "gzip": 114 | var reader io.ReadCloser 115 | reader, err = gzip.NewReader(bytes.NewReader([]byte(bodyBytes))) 116 | if err != nil { 117 | provider.Logger( 118 | *ctx).WithError(err).Error( 119 | "Unable to decompress gzip encoded response") 120 | } 121 
| defer reader.Close() 122 | bb, err := io.ReadAll(reader) 123 | if err != nil { 124 | return "", err 125 | } 126 | 127 | return string(bb), nil 128 | default: 129 | return string(bodyBytes), nil 130 | } 131 | } 132 | 133 | type ISimpleCache interface { 134 | Get(key string) (string, bool) 135 | Update(key, value string) 136 | } 137 | 138 | type InMemorySimpleCache struct { 139 | Cache map[string]struct { 140 | Timestamp time.Time 141 | Value string 142 | } 143 | ExpiryInterval time.Duration 144 | mu sync.Mutex 145 | } 146 | 147 | func (authCache *InMemorySimpleCache) Get(key string) (string, bool) { 148 | authCache.mu.Lock() 149 | defer authCache.mu.Unlock() 150 | 151 | entry, found := authCache.Cache[key] 152 | 153 | if !found { 154 | return "", false 155 | } 156 | 157 | if authCache.ExpiryInterval > 0 && 158 | time.Since(entry.Timestamp) > authCache.ExpiryInterval { 159 | // If entry is older than cachedDuration, then delete the record and return false 160 | delete(authCache.Cache, key) 161 | return "", false 162 | } 163 | return entry.Value, true 164 | } 165 | 166 | func (authCache *InMemorySimpleCache) Update(key, value string) { 167 | authCache.mu.Lock() 168 | defer authCache.mu.Unlock() 169 | 170 | authCache.Cache[key] = struct { 171 | Timestamp time.Time 172 | Value string 173 | }{ 174 | Timestamp: time.Now(), 175 | Value: value, 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /internal/utils/utils_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "context" 7 | "io" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "github.com/razorpay/trino-gateway/pkg/logger" 13 | "github.com/stretchr/testify/suite" 14 | ) 15 | 16 | // Define the suite, and absorb the built-in basic suite 17 | // functionality from testify - including a T() method which 18 | // returns the current testing context 19 | type 
UtilsSuite struct { 20 | suite.Suite 21 | ctx *context.Context 22 | } 23 | 24 | func (suite *UtilsSuite) SetupTest() { 25 | lgrConfig := logger.Config{ 26 | LogLevel: logger.Warn, 27 | } 28 | 29 | l, err := logger.NewLogger(lgrConfig) 30 | if err != nil { 31 | panic("failed to initialize logger") 32 | } 33 | 34 | c := context.WithValue(context.Background(), logger.LoggerCtxKey, l) 35 | 36 | suite.ctx = &c 37 | } 38 | 39 | func (suite *UtilsSuite) Test_IsTimeInCron() { 40 | // func (c *Core) isCurrentTimeInCron(ctx *context.Context, sched string) (bool, error) 41 | 42 | tst := func(t time.Time, sched string) bool { 43 | s, _ := IsTimeInCron(suite.ctx, t, sched) 44 | return s 45 | } 46 | 47 | suite.Equalf( 48 | true, 49 | tst(time.Now(), "* * * * *"), 50 | "Failure", 51 | ) 52 | } 53 | 54 | func (suite *UtilsSuite) Test_getHttpBodyEncoding() { 55 | } 56 | 57 | func (suite *UtilsSuite) Test_stringifyHttpRequestOrResponse() { 58 | } 59 | 60 | func (suite *UtilsSuite) Test_parseBody() { 61 | str := "body" 62 | stringReader := strings.NewReader(str) 63 | stringReadCloser := io.NopCloser(stringReader) 64 | 65 | tst := func() string { 66 | s, _ := ParseHttpPayloadBody(suite.ctx, &stringReadCloser, "") 67 | return s 68 | } 69 | 70 | suite.Equalf(str, tst(), "Failed to extract string from body") 71 | suite.Equalf(str, tst(), "String extraction is not idempotent") 72 | 73 | var b bytes.Buffer 74 | gz := gzip.NewWriter(&b) 75 | if _, err := gz.Write([]byte(str)); err != nil { 76 | panic(err) 77 | } 78 | if err := gz.Close(); err != nil { 79 | panic(err) 80 | } 81 | 82 | strGzipped := b.String() 83 | stringReaderGzipped := strings.NewReader(strGzipped) 84 | stringReadCloserGzipped := io.NopCloser(stringReaderGzipped) 85 | 86 | tst_gzipped := func() string { 87 | s, _ := ParseHttpPayloadBody(suite.ctx, &stringReadCloserGzipped, "gzip") 88 | return s 89 | } 90 | 91 | suite.Equalf(str, tst_gzipped(), "Failed to extract string from body") 92 | suite.Equalf(str, tst_gzipped(), 
"String extraction is not idempotent") 93 | } 94 | 95 | func (suite *UtilsSuite) Test_InMemorySimpleCache_Get() { 96 | authCache := &InMemorySimpleCache{ 97 | Cache: make(map[string]struct { 98 | Timestamp time.Time 99 | Value string 100 | }), 101 | } 102 | key := "testKey" 103 | value := "testValue" 104 | authCache.Cache[key] = struct { 105 | Timestamp time.Time 106 | Value string 107 | }{ 108 | Timestamp: time.Now(), 109 | Value: value, 110 | } 111 | 112 | entry, exists := authCache.Get(key) 113 | suite.Truef(exists, "Entry not found in cache.") 114 | if exists { 115 | suite.Equalf(value, entry, "Cached value doesn't match.") 116 | } 117 | } 118 | 119 | func (suite *UtilsSuite) Test_InMemorySimpleCache_Get_InfiniteExpiry() { 120 | authCache := &InMemorySimpleCache{ 121 | Cache: make(map[string]struct { 122 | Timestamp time.Time 123 | Value string 124 | }), 125 | ExpiryInterval: 0 * time.Second, 126 | } 127 | key := "testKey" 128 | value := "testValue" 129 | authCache.Cache[key] = struct { 130 | Timestamp time.Time 131 | Value string 132 | }{ 133 | Timestamp: time.Now().Add(-1000 * time.Hour), 134 | Value: value, 135 | } 136 | 137 | entry, exists := authCache.Get(key) 138 | suite.Truef(exists, "Entry not found in cache.") 139 | if exists { 140 | suite.Equalf(value, entry, "Cached value doesn't match.") 141 | } 142 | } 143 | 144 | func (suite *UtilsSuite) Test_InMemorySimpleCache_Get_Expired() { 145 | expiryInterval := 2 * time.Second 146 | authCache := &InMemorySimpleCache{ 147 | Cache: make(map[string]struct { 148 | Timestamp time.Time 149 | Value string 150 | }), 151 | ExpiryInterval: expiryInterval, 152 | } 153 | key := "testKey" 154 | value := "testValue" 155 | authCache.Cache[key] = struct { 156 | Timestamp time.Time 157 | Value string 158 | }{ 159 | Timestamp: time.Now().Add(-1 * expiryInterval).Add(-1 * time.Second), 160 | Value: value, 161 | } 162 | 163 | _, exists := authCache.Get(key) 164 | suite.False(exists, "Entry not expired.") 165 | } 166 | 167 | 
func (suite *UtilsSuite) Test_InMemorySimpleCache_Update() { 168 | authCache := &InMemorySimpleCache{ 169 | Cache: make(map[string]struct { 170 | Timestamp time.Time 171 | Value string 172 | }), 173 | } 174 | key := "testKey" 175 | value := "testValue" 176 | authCache.Update(key, value) 177 | 178 | entry, exists := authCache.Cache[key] 179 | suite.Truef(exists, "Entry not found in cache.") 180 | if exists { 181 | suite.Equalf(value, entry.Value, "Cached value doesn't match.") 182 | } 183 | } 184 | 185 | func TestSuite(t *testing.T) { 186 | suite.Run(t, new(UtilsSuite)) 187 | } 188 | -------------------------------------------------------------------------------- /pkg/config/config.go: -------------------------------------------------------------------------------- 1 | // Package config has specific primitives for loading application configurations. 2 | // 3 | // Primitives: 4 | // - Application should have struct for containing configuration. E.g. refer 5 | // trino-gateway/internal/config/config.go file. 6 | // - Application should have a directory holding default file and environment 7 | // specific file. E.g. refer trino-gateway/configs/* directory. 8 | // 9 | // Usage: 10 | // - E.g. NewDefaultConfig().Load("dev", &config), where config is a struct 11 | // where configuration gets unmarshalled into. 12 | package config 13 | 14 | import ( 15 | "os" 16 | "path" 17 | "runtime" 18 | "strings" 19 | 20 | "github.com/spf13/viper" 21 | ) 22 | 23 | // Default options for configuration loading. 24 | const ( 25 | DefaultConfigType = "toml" 26 | DefaultConfigDir = "./config" 27 | DefaultConfigFileName = "default" 28 | WorkDirEnv = "WORKDIR" 29 | ) 30 | 31 | // Options is config options. 32 | type Options struct { 33 | configType string 34 | configPath string 35 | defaultConfigFileName string 36 | } 37 | 38 | // Config is a wrapper over a underlying config loader implementation. 
39 | type Config struct { 40 | opts Options 41 | viper *viper.Viper 42 | } 43 | 44 | // NewDefaultOptions returns default options. 45 | // DISCLAIMER: This function is a bit hacky 46 | // This function expects an env $WORKDIR to 47 | // be set and reads configs from $WORKDIR/configs. 48 | // If $WORKDIR is not set. It uses the absolute path wrt 49 | // the location of this file (config.go) to set configPath 50 | // to 2 levels up in viper (../../configs). 51 | // This function breaks if : 52 | // 1. $WORKDIR is set and configs dir not present in $WORKDIR 53 | // 2. $WORKDIR is not set and ../../configs is not present 54 | // 3. $WORKDIR is not set and runtime absolute path of configs 55 | // is different than build time path as runtime.Caller() evaluates 56 | // only at build time 57 | func NewDefaultOptions() Options { 58 | var configPath string 59 | workDir := os.Getenv(WorkDirEnv) 60 | if workDir != "" { 61 | configPath = path.Join(workDir, DefaultConfigDir) 62 | } else { 63 | _, thisFile, _, _ := runtime.Caller(1) 64 | configPath = path.Join(path.Dir(thisFile), "../../"+DefaultConfigDir) 65 | } 66 | return NewOptions(DefaultConfigType, configPath, DefaultConfigFileName) 67 | } 68 | 69 | // NewOptions returns new Options struct. 70 | func NewOptions(configType string, configPath string, defaultConfigFileName string) Options { 71 | return Options{configType, configPath, defaultConfigFileName} 72 | } 73 | 74 | // NewDefaultConfig returns new config struct with default options. 75 | func NewDefaultConfig() *Config { 76 | return NewConfig(NewDefaultOptions()) 77 | } 78 | 79 | // NewConfig returns new config struct. 80 | func NewConfig(opts Options) *Config { 81 | return &Config{opts, viper.New()} 82 | } 83 | 84 | // Load reads environment specific configurations and along with the defaults 85 | // unmarshalls into config. 
86 | func (c *Config) Load(env string, config interface{}) error { 87 | if err := c.loadByConfigName(c.opts.defaultConfigFileName, config); err != nil { 88 | return err 89 | } 90 | return c.loadByConfigName(env, config) 91 | } 92 | 93 | // loadByConfigName reads configuration from file and unmarshalls into config. 94 | func (c *Config) loadByConfigName(configName string, config interface{}) error { 95 | c.viper.SetEnvPrefix(strings.ToUpper("trino-gateway")) 96 | c.viper.SetConfigName(configName) 97 | c.viper.SetConfigType(c.opts.configType) 98 | c.viper.AddConfigPath(c.opts.configPath) 99 | c.viper.AutomaticEnv() 100 | c.viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) 101 | if err := c.viper.ReadInConfig(); err != nil { 102 | return err 103 | } 104 | return c.viper.Unmarshal(config) 105 | } 106 | -------------------------------------------------------------------------------- /pkg/config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | type TestConfig struct { 13 | Title string 14 | Db TestDbConfig 15 | } 16 | 17 | type TestDbConfig struct { 18 | Dialect string 19 | Protocol string 20 | Host string 21 | Port int 22 | Username string 23 | Password string 24 | SslMode string 25 | Name string 26 | MaxOpenConnections int 27 | MaxIdleConnections int 28 | ConnectionMaxLifetime time.Duration 29 | } 30 | 31 | func TestLoadConfig(t *testing.T) { 32 | var c TestConfig 33 | 34 | key := strings.ToUpper("trino-gateway") + "_DB_PASSWORD" 35 | os.Setenv(key, "envpass") 36 | err := NewConfig(NewOptions("toml", "./testdata", "default")).Load("default", &c) 37 | assert.Nil(t, err) 38 | // Asserts that default value exists. 39 | assert.Equal(t, "mysql", c.Db.Dialect) 40 | // Asserts that application environment specific value got overridden. 
41 | assert.Equal(t, 10, c.Db.MaxOpenConnections) 42 | // Asserts that environment variable was honored. 43 | assert.Equal(t, "envpass", c.Db.Password) 44 | } 45 | -------------------------------------------------------------------------------- /pkg/config/testdata/default.toml: -------------------------------------------------------------------------------- 1 | title = "Default config file for testing." 2 | 3 | [db] 4 | dialect = "mysql" 5 | protocol = "tcp" 6 | host = "localhost" 7 | port = 3306 8 | username = "trino-gateway" 9 | password = "trino-gateway" 10 | sslMode = "require" 11 | name = "trino-gateway" 12 | maxOpenConnections = 10 13 | maxIdleConnections = 10 14 | connectionMaxLifetime = 0 15 | -------------------------------------------------------------------------------- /pkg/fetcher/fetcher.go: -------------------------------------------------------------------------------- 1 | package fetcher 2 | 3 | const MaxLimit = 2000 4 | 5 | // Request for fetching multiple entities 6 | type FetchMultipleRequest struct { 7 | EntityName string 8 | Filter map[string]interface{} 9 | Pagination Pagination 10 | TimeRange TimeRange 11 | IsTrashed bool 12 | HasCreatedAt bool 13 | } 14 | 15 | type IFetchMultipleRequest interface { 16 | GetEntityName() string 17 | GetFilter() map[string]interface{} 18 | GetPagination() IPagination 19 | GetTimeRange() ITimeRange 20 | GetTrashed() bool 21 | ContainsCreatedAt() bool 22 | } 23 | 24 | // GetEntityName : name of the entity to be fetched. 25 | func (fr FetchMultipleRequest) GetEntityName() string { 26 | return fr.EntityName 27 | } 28 | 29 | // GetFilter : filter conditions for fetching the entities. 30 | func (fr FetchMultipleRequest) GetFilter() map[string]interface{} { 31 | return fr.Filter 32 | } 33 | 34 | // GetPagination : skip and limit values for multiple fetch. 35 | func (fr FetchMultipleRequest) GetPagination() IPagination { 36 | return fr.Pagination 37 | } 38 | 39 | // GetTimeRange : time range for multiple fetch query. 
40 | func (fr FetchMultipleRequest) GetTimeRange() ITimeRange { 41 | return fr.TimeRange 42 | } 43 | 44 | // GetTrashed : fetch soft deleted entities too if true. 45 | func (fr FetchMultipleRequest) GetTrashed() bool { 46 | return fr.IsTrashed 47 | } 48 | 49 | // ContainsCreatedAt: account_balance table doesn't have create_at field in capital_loc.. 50 | func (fr FetchMultipleRequest) ContainsCreatedAt() bool { 51 | return fr.HasCreatedAt 52 | } 53 | 54 | // skip and limit values for multiple fetch. 55 | type Pagination struct { 56 | Limit int 57 | Skip int 58 | } 59 | 60 | type IPagination interface { 61 | GetLimit() int 62 | GetOffset() int 63 | } 64 | 65 | // GetLimit : get limit value for pagination. 66 | func (p Pagination) GetLimit() int { 67 | if p.Limit > MaxLimit { 68 | return MaxLimit 69 | } 70 | return p.Limit 71 | } 72 | 73 | // GetOffset : get skip value for pagination. 74 | func (p Pagination) GetOffset() int { 75 | return p.Skip 76 | } 77 | 78 | // Time range for the entity fetch request. 79 | type TimeRange struct { 80 | From int64 81 | To int64 82 | } 83 | 84 | type ITimeRange interface { 85 | GetTo() int64 86 | GetFrom() int64 87 | } 88 | 89 | // GetFrom : from timestamp. 90 | func (tr TimeRange) GetFrom() int64 { 91 | return tr.From 92 | } 93 | 94 | // GetTo : to timestamp. 95 | func (tr TimeRange) GetTo() int64 { 96 | return tr.To 97 | } 98 | 99 | // Response for multiple fetch request. 100 | type FetchMultipleResponse struct { 101 | entities map[string]interface{} 102 | } 103 | 104 | type IFetchMultipleResponse interface { 105 | GetEntities() interface{} 106 | } 107 | 108 | // GetEntities : returns the fetched entities. 109 | func (fr FetchMultipleResponse) GetEntities() interface{} { 110 | return fr.entities 111 | } 112 | 113 | // Single entity fetch request. 
114 | type FetchRequest struct { 115 | EntityName string 116 | ID string 117 | IsTrashed bool 118 | HasCreatedAt bool 119 | } 120 | 121 | type IFetchRequest interface { 122 | GetID() string 123 | GetTrashed() bool 124 | GetEntityName() string 125 | } 126 | 127 | // GetID : Id of the entity to be fetched. 128 | func (fr FetchRequest) GetID() string { 129 | return fr.ID 130 | } 131 | 132 | // GetEntityName : Name of the entity to be fetched. 133 | func (fr FetchRequest) GetEntityName() string { 134 | return fr.EntityName 135 | } 136 | 137 | // GetTrashed : Include soft deleted records if true. 138 | func (fr FetchRequest) GetTrashed() bool { 139 | return fr.IsTrashed 140 | } 141 | 142 | type FetchResponse struct { 143 | entity interface{} 144 | } 145 | 146 | // Response for single entity fetch 147 | type IFetchResponse interface { 148 | GetEntity() interface{} 149 | } 150 | 151 | // Return fetched entity. 152 | func (fr FetchResponse) GetEntity() interface{} { 153 | return fr.entity 154 | } 155 | -------------------------------------------------------------------------------- /pkg/fetcher/manager.go: -------------------------------------------------------------------------------- 1 | package fetcher 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "sync" 8 | 9 | "gorm.io/gorm" 10 | ) 11 | 12 | // Client for fetcher package. 13 | type client struct { 14 | db *gorm.DB 15 | entities entityList 16 | } 17 | 18 | // IClient exposed methods of fetcher package. 19 | type IClient interface { 20 | Register(name string, entity IModel, entities interface{}) error 21 | IsEntityRegistered(name string) bool 22 | GetEntityList() []string 23 | Fetch(ctx context.Context, req IFetchRequest) (IFetchResponse, error) 24 | FetchMultiple(ctx context.Context, req IFetchMultipleRequest) (IFetchMultipleResponse, error) 25 | } 26 | 27 | // The entities to be fetched should implement IModel. 
28 | type IModel interface { 29 | TableName() string 30 | } 31 | 32 | type entityType struct { 33 | model IModel 34 | models interface{} 35 | } 36 | 37 | type entityList struct { 38 | sync.Mutex 39 | list map[string]entityType 40 | } 41 | 42 | // New : returns a new fetcher client. 43 | func New(db *gorm.DB) IClient { 44 | return &client{ 45 | db: db, 46 | entities: entityList{ 47 | list: make(map[string]entityType), 48 | }, 49 | } 50 | } 51 | 52 | // Register : register the entity with the fetcher. 53 | // An entity can be registered only once. Only the 54 | // entities registered can be fetched using fetcher client. 55 | func (c *client) Register(name string, model IModel, models interface{}) error { 56 | c.entities.Lock() 57 | defer c.entities.Unlock() 58 | 59 | if _, ok := c.entities.list[name]; ok { 60 | return fmt.Errorf("entity %s name already registered", name) 61 | } 62 | c.entities.list[name] = entityType{ 63 | model: model, 64 | models: models, 65 | } 66 | return nil 67 | } 68 | 69 | func (c *client) IsEntityRegistered(name string) bool { 70 | if _, ok := c.entities.list[name]; ok { 71 | return true 72 | } 73 | return false 74 | } 75 | 76 | // GetEntityList : returns the list of entities registered with fetcher package. 77 | func (c *client) GetEntityList() []string { 78 | list := make([]string, 0, len(c.entities.list)) 79 | 80 | for k := range c.entities.list { 81 | list = append(list, k) 82 | } 83 | 84 | return list 85 | } 86 | 87 | // FetchMultiple : returns multiple record of the entities based on the fetch multiple request parameters. 
88 | func (c *client) FetchMultiple(_ context.Context, req IFetchMultipleRequest) (IFetchMultipleResponse, error) { 89 | var ( 90 | dataTypes entityType 91 | ok bool 92 | ) 93 | 94 | if dataTypes, ok = c.entities.list[req.GetEntityName()]; !ok { 95 | return nil, fmt.Errorf("entity %s not registered with fetcher", req.GetEntityName()) 96 | } 97 | 98 | models := clone(dataTypes.models) 99 | var query *gorm.DB 100 | if req.ContainsCreatedAt() { 101 | query = c.db. 102 | Order("created_at DESC"). 103 | Limit(req.GetPagination().GetLimit()). 104 | Offset(req.GetPagination().GetOffset()) 105 | 106 | if req.GetTimeRange().GetFrom() != 0 && req.GetTimeRange().GetTo() != 0 { 107 | query = query.Where("created_at between ? and ?", req.GetTimeRange().GetFrom(), req.GetTimeRange().GetTo()) 108 | } 109 | } else { 110 | query = c.db. 111 | Limit(req.GetPagination().GetLimit()). 112 | Offset(req.GetPagination().GetOffset()) 113 | } 114 | 115 | if req.GetTrashed() { 116 | query = query.Unscoped() 117 | } 118 | 119 | if len(req.GetFilter()) >= 1 { 120 | query = query.Where(req.GetFilter()) 121 | } 122 | query = query.Find(models) 123 | return FetchMultipleResponse{ 124 | entities: map[string]interface{}{ 125 | req.GetEntityName(): models, 126 | }, 127 | }, query.Error 128 | } 129 | 130 | // Fetch : returns single entity from the using id. 
131 | func (c *client) Fetch(_ context.Context, req IFetchRequest) (IFetchResponse, error) { 132 | var ( 133 | dataTypes entityType 134 | ok bool 135 | ) 136 | 137 | if dataTypes, ok = c.entities.list[req.GetEntityName()]; !ok { 138 | return nil, fmt.Errorf("entity %s not registered with fetcher", req.GetEntityName()) 139 | } 140 | 141 | model := clone(dataTypes.model) 142 | query := c.db.Where("id = ?", req.GetID()) 143 | 144 | if req.GetTrashed() { 145 | query = query.Unscoped() 146 | } 147 | 148 | query = query.First(model) 149 | return FetchResponse{ 150 | entity: model, 151 | }, query.Error 152 | } 153 | 154 | func clone(src interface{}) interface{} { 155 | return reflect. 156 | New(reflect.TypeOf(src).Elem()). 157 | Interface() 158 | } 159 | -------------------------------------------------------------------------------- /pkg/logger/entry_test.go: -------------------------------------------------------------------------------- 1 | package logger_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/razorpay/trino-gateway/pkg/logger" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestEntryWithError(t *testing.T) { 13 | assert := assert.New(t) 14 | 15 | defer func() { 16 | logger.ErrorKey = "error" 17 | }() 18 | 19 | err := fmt.Errorf("doomed here %d", 1234) 20 | 21 | config := logger.Config{ 22 | LogLevel: logger.Debug, 23 | } 24 | 25 | lgr, lErr := logger.NewLogger(config) 26 | 27 | assert.Nil(lErr) 28 | 29 | entry := logger.NewEntry(lgr) 30 | 31 | assert.Equal(err.Error(), entry.WithError(err).Data["error"]) 32 | 33 | logger.ErrorKey = "err" 34 | assert.Equal(err.Error(), entry.WithError(err).Data["err"]) 35 | } 36 | 37 | func TestEntryWithContext(t *testing.T) { 38 | assert := assert.New(t) 39 | ctx := context.WithValue(context.Background(), "foo", "bar") 40 | 41 | config := logger.Config{ 42 | LogLevel: logger.Debug, 43 | } 44 | 45 | lgr, err := logger.NewLogger(config) 46 | 47 | assert.Nil(err) 48 | 49 | entry := 
logger.NewEntry(lgr) 50 | 51 | assert.Equal("bar", entry.WithContext(ctx, []string{"foo"}).Data["foo"]) 52 | } 53 | -------------------------------------------------------------------------------- /pkg/logger/logger_test.go: -------------------------------------------------------------------------------- 1 | package logger_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/razorpay/trino-gateway/pkg/logger" 7 | ) 8 | 9 | func BenchmarkNewLogger(b *testing.B) { 10 | config := logger.Config{ 11 | LogLevel: logger.Debug, 12 | } 13 | lgr, err := logger.NewLogger(config) 14 | if err != nil { 15 | b.Errorf("NewLogger :%v", err) 16 | } 17 | /*zt := reflect.TypeOf(&(logger.ZapLogger{})) 18 | if !reflect.DeepEqual(lg, zt) { 19 | t.Errorf("NewLogger() = %v, want %v", lg, zt) 20 | }*/ 21 | defaultLogger := lgr.WithFields(map[string]interface{}{"key1": "value1"}) 22 | for n := 0; n < 1000; n++ { 23 | b.Run("defaultlogger", func(b *testing.B) { 24 | defaultLogger.Debug("I have the default values for Key1 and Value1") 25 | }) 26 | } 27 | } 28 | 29 | func LogTest(b *testing.B) { 30 | config := logger.Config{ 31 | LogLevel: logger.Debug, 32 | } 33 | lgr, err := logger.NewLogger(config) 34 | if err != nil { 35 | b.Errorf("NewLogger :%v", err) 36 | } 37 | defaultLogger := lgr.WithFields(map[string]interface{}{"key1": "value1"}) 38 | for n := 0; n < b.N; n++ { 39 | defaultLogger.Debug("I have the default values for Key1 and Value1") 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /pkg/spine/datatype/validation.go: -------------------------------------------------------------------------------- 1 | package datatype 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "regexp" 8 | "strconv" 9 | 10 | validation "github.com/go-ozzo/ozzo-validation/v4" 11 | "github.com/gobuffalo/nulls" 12 | ) 13 | 14 | const ( 15 | // regular expression to validation rzp standard ID 16 | RegexRZPID = `^[a-zA-Z0-9]{14}$` 17 | 18 | // 
regular expression to validation unix timestamp 19 | RegexUnixTimestamp = `^([\d]{10}|0)$` 20 | 21 | // regular expression to validation basic string 22 | // consisting alpha, number, _, - and white space 23 | RegexBasicString = `^[a-zA-Z0-9_\-\s]*$` 24 | ) 25 | 26 | // ValidateNullableInt64 checks NullableInt64 against the rules provided 27 | func ValidateNullableInt64(value nulls.Int64, rules ...validation.RuleFunc) validation.RuleFunc { 28 | return func(value interface{}) error { 29 | v := value.(nulls.Int64) 30 | for _, f := range rules { 31 | if err := f(v.Int64); err != nil { 32 | return err 33 | } 34 | } 35 | 36 | return nil 37 | } 38 | } 39 | 40 | // ValidateNullableString checks NullableString against the rules provided 41 | func ValidateNullableString(value nulls.String, rules ...validation.RuleFunc) validation.RuleFunc { 42 | return func(value interface{}) error { 43 | v := value.(nulls.String) 44 | for _, f := range rules { 45 | if err := f(v.String); err != nil { 46 | return err 47 | } 48 | } 49 | 50 | return nil 51 | } 52 | } 53 | 54 | // IsRZPID checks the given string is a valid 14-char Razorpay ID 55 | func IsRZPID(value interface{}) error { 56 | return isValidString(value, RegexRZPID) 57 | } 58 | 59 | // IsTimestamp will validate if the value is a valid unix timestamp or not 60 | func IsTimestamp(value interface{}) error { 61 | if value == nil { 62 | return nil 63 | } 64 | 65 | return MatchRegex(fmt.Sprintf("%v", value), RegexUnixTimestamp) 66 | } 67 | 68 | // MatchRegex checks if given input matches a given regex or not 69 | func MatchRegex(value string, regex string) error { 70 | if validString, err := regexp.Compile(regex); err != nil { 71 | return errors.New("invalid regex") 72 | } else if !validString.MatchString(value) { 73 | return errors.New("not a valid input") 74 | } 75 | 76 | return nil 77 | } 78 | 79 | // isValidString checks if given input matches a given regex or not 80 | func isValidString(value interface{}, regex string) error { 81 
| // let the nil handled by required validation 82 | if value == nil { 83 | return nil 84 | } 85 | 86 | if str, err := isString(value); err != nil { 87 | return err 88 | } else if str == "" { 89 | return nil 90 | } else { 91 | return MatchRegex(str, regex) 92 | } 93 | } 94 | 95 | // isString checks if the given data is valid string or not 96 | func isString(value interface{}) (string, error) { 97 | if str, ok := value.(string); !ok { 98 | return "", errors.New("must be a string") 99 | } else { 100 | return str, nil 101 | } 102 | } 103 | 104 | // IsBasicString validates if the given value is basic string 105 | // consisting alphabet, number, _, - and white space 106 | func IsBasicString(value interface{}) error { 107 | return isValidString(value, RegexBasicString) 108 | } 109 | 110 | // IsInt64 checks if the given data is valid int64 or not 111 | func IsInt64(value interface{}) error { 112 | if _, ok := value.(int64); !ok { 113 | return errors.New("must be an int64 integer") 114 | } 115 | return nil 116 | } 117 | 118 | // IsJson validates if the given value is valid json 119 | func IsJson(value interface{}) error { 120 | if value == nil { 121 | return nil 122 | } 123 | 124 | if str, err := isString(value); err != nil { 125 | return err 126 | } else { 127 | input := []byte(str) 128 | var x struct{} 129 | if err := json.Unmarshal(input, &x); err != nil { 130 | return err 131 | } 132 | } 133 | 134 | return nil 135 | } 136 | 137 | // IsNumeric checks if the given string 's' is float, int, signed / unsigned, exponential 138 | // and returns true if it is valid .. 
139 | func IsNumeric(value interface{}) error { 140 | valueStr, err := isString(value) 141 | if err != nil { 142 | return err 143 | } 144 | _, err = strconv.ParseFloat(valueStr, 64) 145 | 146 | return err 147 | } 148 | 149 | // IsBool checks if the given data is boolean or not 150 | func IsBool(value interface{}) error { 151 | switch value.(type) { 152 | case nil: 153 | return nil 154 | case bool: 155 | return nil 156 | default: 157 | return errors.New("must be a boolean") 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /pkg/spine/db/db_test.go: -------------------------------------------------------------------------------- 1 | package db_test 2 | 3 | import ( 4 | "context" 5 | "regexp" 6 | "testing" 7 | "time" 8 | 9 | "gorm.io/plugin/dbresolver" 10 | 11 | "github.com/DATA-DOG/go-sqlmock" 12 | "github.com/stretchr/testify/assert" 13 | "gorm.io/driver/mysql" 14 | "gorm.io/gorm" 15 | 16 | "github.com/razorpay/trino-gateway/pkg/spine" 17 | "github.com/razorpay/trino-gateway/pkg/spine/db" 18 | ) 19 | 20 | type TestModel struct { 21 | spine.Model 22 | Name string `json:"name"` 23 | } 24 | 25 | func (t *TestModel) EntityName() string { 26 | return "model" 27 | } 28 | 29 | func (t *TestModel) TableName() string { 30 | return "model" 31 | } 32 | 33 | func (t *TestModel) GetID() string { 34 | return t.ID 35 | } 36 | 37 | func (t *TestModel) Validate() error { 38 | return nil 39 | } 40 | 41 | func (t *TestModel) SetDefaults() error { 42 | return nil 43 | } 44 | 45 | func TestGetConnectionPath(t *testing.T) { 46 | c := getDefaultConfig() 47 | // Asserts connection string for mysql dialect. 48 | assert.Equal(t, "user:pass@tcp(localhost:3307)/database?charset=utf8&parseTime=True&loc=Local", c.GetConnectionPath()) 49 | // Asserts connection string for postgres dialect. 
50 | c.Dialect = "postgres" 51 | assert.Equal(t, "host=localhost port=3307 dbname=database sslmode=require user=user password=pass", c.GetConnectionPath()) 52 | 53 | // invalid dialect 54 | c.Dialect = "invalid" 55 | assert.Equal(t, "", c.GetConnectionPath()) 56 | } 57 | 58 | func TestNewDb(t *testing.T) { 59 | tests := []struct { 60 | name string 61 | err string 62 | config db.IConfigReader 63 | options func() ([]func(*db.DB) error, func()) 64 | }{ 65 | { 66 | name: "success", 67 | config: getDefaultConfig(), 68 | options: func() ([]func(*db.DB) error, func()) { 69 | conn, _, err := sqlmock.New() 70 | assert.Nil(t, err) 71 | return []func(*db.DB) error{ 72 | db.Dialector(getGormDialectorForMock(conn)), 73 | }, func() { _ = conn.Close() } 74 | }, 75 | }, 76 | { 77 | name: "invalid dialect", 78 | err: db.ErrorUndefinedDialect.Error(), 79 | config: &db.Config{ConnectionConfig: db.ConnectionConfig{Dialect: "invalid dialect"}}, 80 | options: func() ([]func(*db.DB) error, func()) { 81 | return []func(*db.DB) error{}, func() {} 82 | }, 83 | }, 84 | { 85 | name: "connect error: no mock", 86 | err: "dial tcp .+:3307: connect: connection refused", 87 | config: getDefaultConfig(), 88 | options: func() ([]func(*db.DB) error, func()) { 89 | return []func(*db.DB) error{}, func() {} 90 | }, 91 | }, 92 | } 93 | 94 | for _, testCase := range tests { 95 | t.Run(testCase.name, func(t *testing.T) { 96 | options, finish := testCase.options() 97 | defer finish() 98 | 99 | gdb, err := db.NewDb(testCase.config, options...) 
100 | 101 | if testCase.err == "" { 102 | assert.Nil(t, err) 103 | assert.NotNil(t, gdb) 104 | } else { 105 | expr, e := regexp.Compile(testCase.err) 106 | assert.Nil(t, e) 107 | assert.NotNil(t, err) 108 | assert.Nil(t, gdb) 109 | assert.Regexp(t, expr, err.Error()) 110 | } 111 | }) 112 | } 113 | } 114 | 115 | func TestDB_Replica(t *testing.T) { 116 | defConn, _, err := sqlmock.New() 117 | assert.Nil(t, err) 118 | defer defConn.Close() 119 | 120 | replicaConn, replicaMock, err := sqlmock.New() 121 | assert.Nil(t, err) 122 | defer replicaConn.Close() 123 | 124 | sdb, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(defConn)), db.GormConfig(&gorm.Config{})) 125 | assert.Nil(t, err) 126 | 127 | err = sdb.Replicas([]gorm.Dialector{getGormDialectorForMock(replicaConn)}, &db.ConnectionPoolConfig{ 128 | MaxOpenConnections: 5, 129 | MaxIdleConnections: 5, 130 | ConnectionMaxLifetime: 5 * time.Minute, 131 | }) 132 | assert.Nil(t, err) 133 | 134 | model := TestModel{} 135 | 136 | // 1. Test that with replica select query goes to replica 137 | replicaMock. 138 | ExpectQuery(regexp.QuoteMeta("SELECT * FROM `model` WHERE id = ? ORDER BY `model`.`id` LIMIT 1")). 
139 | WithArgs("1").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "name1")) 140 | 141 | tx := sdb.Instance(context.TODO()).Where("id = ?", "1").First(&model) 142 | assert.Nil(t, tx.Error) 143 | assert.Equal(t, model.ID, "1") 144 | } 145 | 146 | func TestDB_WarmStorageDB(t *testing.T) { 147 | defConn, _, err := sqlmock.New() 148 | assert.Nil(t, err) 149 | defer defConn.Close() 150 | 151 | warmStorageConn, warmStorageMock, err := sqlmock.New() 152 | assert.Nil(t, err) 153 | defer warmStorageConn.Close() 154 | 155 | newDB, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(defConn)), db.GormConfig(&gorm.Config{})) 156 | assert.Nil(t, err) 157 | 158 | err = newDB.WarmStorageDB([]gorm.Dialector{getGormDialectorForMock(warmStorageConn)}, &db.ConnectionPoolConfig{ 159 | MaxOpenConnections: 5, 160 | MaxIdleConnections: 5, 161 | ConnectionMaxLifetime: 5 * time.Minute, 162 | }) 163 | assert.Nil(t, err) 164 | 165 | model := TestModel{} 166 | 167 | // 1. Test that with warm storage select query goes to warm storage 168 | warmStorageMock. 169 | ExpectQuery(regexp.QuoteMeta("SELECT * FROM `model` WHERE id = ? ORDER BY `model`.`id` LIMIT 1")). 
170 | WithArgs("1").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "name1")) 171 | 172 | tx := newDB.Instance(context.TODO()).Clauses(dbresolver.Use(db.WarmStorageDBResolverName)).Where("id = ?", "1").First(&model) 173 | assert.Nil(t, tx.Error) 174 | assert.Equal(t, model.ID, "1") 175 | } 176 | 177 | func TestDb_Alive(t *testing.T) { 178 | conn, _, err := sqlmock.New() 179 | assert.Nil(t, err) 180 | defer conn.Close() 181 | 182 | gdb, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(conn))) 183 | assert.Nil(t, err) 184 | assert.NotNil(t, gdb) 185 | 186 | err = gdb.Alive() 187 | assert.Nil(t, err) 188 | } 189 | 190 | func TestDB_Instance(t *testing.T) { 191 | conn, _, err := sqlmock.New() 192 | assert.Nil(t, err) 193 | defer conn.Close() 194 | 195 | gdb1, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(conn))) 196 | assert.Nil(t, err) 197 | assert.NotNil(t, gdb1) 198 | 199 | gdb2, err := db.NewDb(getDefaultConfig(), db.Dialector(getGormDialectorForMock(conn))) 200 | assert.Nil(t, err) 201 | assert.NotNil(t, gdb2) 202 | 203 | instance1 := gdb1.Instance(context.TODO()) 204 | assert.NotNil(t, instance1) 205 | 206 | instance2 := gdb2.Instance(context.TODO()) 207 | assert.NotNil(t, instance2) 208 | 209 | ctx := context.WithValue(context.TODO(), db.ContextKeyDatabase, instance2) 210 | tgdb := gdb1.Instance(ctx) 211 | assert.Equal(t, instance2, tgdb) 212 | } 213 | 214 | func getDefaultConfig() *db.Config { 215 | return &db.Config{ 216 | ConnectionPoolConfig: db.ConnectionPoolConfig{ 217 | MaxOpenConnections: 5, 218 | MaxIdleConnections: 5, 219 | ConnectionMaxLifetime: 5 * time.Minute, 220 | }, 221 | ConnectionConfig: db.ConnectionConfig{ 222 | Dialect: "mysql", 223 | Protocol: "tcp", 224 | URL: "localhost", 225 | Port: 3307, 226 | Username: "user", 227 | Password: "pass", 228 | SslMode: "require", 229 | Name: "database", 230 | }, 231 | } 232 | } 233 | 234 | func getGormDialectorForMock(conn gorm.ConnPool) 
gorm.Dialector { 235 | return mysql.New(mysql.Config{Conn: conn, SkipInitializeWithVersion: true}) 236 | } 237 | -------------------------------------------------------------------------------- /pkg/spine/errors.go: -------------------------------------------------------------------------------- 1 | package spine 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/go-sql-driver/mysql" 7 | 8 | "github.com/lib/pq" 9 | "gorm.io/gorm" 10 | ) 11 | 12 | const ( 13 | PQCodeUniqueViolation = "unique_violation" 14 | MySqlCodeUniqueViolation = 1062 15 | 16 | errDBError = "db_error" 17 | errNoRowAffected = "no_row_affected" 18 | errRecordNotFound = "record_not_found" 19 | errValidationFailure = "validation_failure" 20 | errUniqueConstraintViolation = "unique_constraint_violation" 21 | ) 22 | 23 | var ( 24 | DBError = errors.New(errDBError) 25 | NoRowAffected = errors.New(errNoRowAffected) 26 | RecordNotFound = errors.New(errRecordNotFound) 27 | ValidationFailure = errors.New(errValidationFailure) 28 | UniqueConstraintViolation = errors.New(errUniqueConstraintViolation) 29 | ) 30 | 31 | // GetDBError accepts db instance and the details 32 | // creates appropriate error based on the type of query result 33 | // if there is no error then returns nil 34 | func GetDBError(db *gorm.DB) error { 35 | if db.Error == nil { 36 | return nil 37 | } 38 | 39 | // check of error is specific to dialect 40 | if de, ok := DialectError(db); ok { 41 | // is the specific error is captured then return it 42 | // else try construct further errors 43 | if err := de.ConstructError(); err != nil { 44 | return err 45 | } 46 | } 47 | 48 | // Construct error based on type of db operation 49 | err := func() error { 50 | switch true { 51 | case errors.Is(db.Error, gorm.ErrRecordNotFound): 52 | return RecordNotFound 53 | 54 | default: 55 | return db.Error 56 | } 57 | }() 58 | 59 | // add specific details of error 60 | return err 61 | } 62 | 63 | // GetValidationError wraps the error and returns instance of 
ValidationError 64 | // if the provided error is nil then it just returns nil 65 | func GetValidationError(err error) error { 66 | if err != nil { 67 | return err 68 | } 69 | 70 | return nil 71 | } 72 | 73 | // DialectError returns true if the error is from dialect 74 | func DialectError(d *gorm.DB) (IDialectError, bool) { 75 | switch d.Error.(type) { 76 | case *pq.Error: 77 | return pqError{d.Error.(*pq.Error)}, true 78 | case *mysql.MySQLError: 79 | return mysqlError{d.Error.(*mysql.MySQLError)}, true 80 | default: 81 | return nil, false 82 | } 83 | } 84 | 85 | // IDialectError interface to handler dialect related errors 86 | type IDialectError interface { 87 | ConstructError() error 88 | } 89 | 90 | // pqError holds the error occurred by postgres 91 | type pqError struct { 92 | err *pq.Error 93 | } 94 | 95 | // ConstructError will create appropriate error based on dialect 96 | func (pqe pqError) ConstructError() error { 97 | switch pqe.err.Code.Name() { 98 | case PQCodeUniqueViolation: 99 | return pqe.err 100 | default: 101 | return nil 102 | } 103 | } 104 | 105 | type mysqlError struct { 106 | err *mysql.MySQLError 107 | } 108 | 109 | // ConstructError will create appropriate error based on dialect 110 | func (msqle mysqlError) ConstructError() error { 111 | switch msqle.err.Number { 112 | case MySqlCodeUniqueViolation: 113 | return msqle.err 114 | 115 | default: 116 | return nil 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /pkg/spine/model.go: -------------------------------------------------------------------------------- 1 | package spine 2 | 3 | import ( 4 | validation "github.com/go-ozzo/ozzo-validation/v4" 5 | "github.com/razorpay/trino-gateway/pkg/spine/datatype" 6 | ) 7 | 8 | const ( 9 | AttributeID = "id" 10 | AttributeCreatedAt = "created_at" 11 | AttributeUpdatedAt = "updated_at" 12 | AttributeDeletedAt = "deleted_at" 13 | ) 14 | 15 | type Model struct { 16 | ID string `json:"id"` 17 | 
CreatedAt int64 `json:"created_at"` 18 | UpdatedAt int64 `json:"updated_at"` 19 | } 20 | 21 | type IModel interface { 22 | TableName() string 23 | EntityName() string 24 | GetID() string 25 | Validate() error 26 | SetDefaults() error 27 | } 28 | 29 | // Validate validates base Model. 30 | func (m *Model) Validate() error { 31 | return GetValidationError( 32 | validation.ValidateStruct( 33 | m, 34 | validation.Field(&m.ID, validation.By(datatype.IsRZPID)), 35 | validation.Field(&m.CreatedAt, validation.By(datatype.IsTimestamp)), 36 | validation.Field(&m.UpdatedAt, validation.By(datatype.IsTimestamp)), 37 | ), 38 | ) 39 | } 40 | 41 | // GetID gets identifier of entity. 42 | func (m *Model) GetID() string { 43 | return m.ID 44 | } 45 | 46 | // GetCreatedAt gets created time of entity. 47 | func (m *Model) GetCreatedAt() int64 { 48 | return m.CreatedAt 49 | } 50 | 51 | // GetUpdatedAt gets last updated time of entity. 52 | func (m *Model) GetUpdatedAt() int64 { 53 | return m.UpdatedAt 54 | } 55 | -------------------------------------------------------------------------------- /pkg/spine/repository.go: -------------------------------------------------------------------------------- 1 | package spine 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/razorpay/trino-gateway/pkg/spine/db" 7 | "gorm.io/gorm/clause" 8 | "gorm.io/plugin/dbresolver" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | const updatedAtField = "updated_at" 14 | 15 | type Repo struct { 16 | Db *db.DB 17 | } 18 | 19 | // FindByID fetches the record which matches the ID provided from the entity defined by receiver 20 | // and the result will be loaded into receiver 21 | func (repo Repo) FindByID(ctx context.Context, receiver IModel, id string) error { 22 | q := repo.DBInstance(ctx).Where(AttributeID+" = ?", id).First(receiver) 23 | 24 | return GetDBError(q) 25 | } 26 | 27 | func (repo Repo) FindWithConditionByIDs( 28 | ctx context.Context, 29 | receivers interface{}, 30 | condition map[string]interface{}, 31 | 
ids []string, 32 | ) error { 33 | 34 | q := repo.DBInstance(ctx).Where(AttributeID+" in (?)", ids).Where(condition).Find(receivers) 35 | 36 | return GetDBError(q) 37 | } 38 | 39 | // FindByIDs fetches the all the records which matches the IDs provided from the entity defined by receivers 40 | // and the result will be loaded into receivers 41 | func (repo Repo) FindByIDs(ctx context.Context, receivers interface{}, ids []string) error { 42 | q := repo.DBInstance(ctx).Where(AttributeID+" in (?)", ids).Find(receivers) 43 | 44 | return GetDBError(q) 45 | } 46 | 47 | // Create inserts a new record in the entity defined by the receiver 48 | // all data filled in the receiver will inserted 49 | func (repo Repo) Create(ctx context.Context, receiver IModel) error { 50 | if err := receiver.SetDefaults(); err != nil { 51 | return err 52 | } 53 | 54 | if err := receiver.Validate(); err != nil { 55 | return err 56 | } 57 | 58 | q := repo.DBInstance(ctx).Create(receiver) 59 | 60 | return GetDBError(q) 61 | } 62 | 63 | // CreateInBatches insert the value in batches into database 64 | func (repo Repo) CreateInBatches(ctx context.Context, receivers interface{}, batchSize int) error { 65 | q := repo.DBInstance(ctx).CreateInBatches(receivers, batchSize) 66 | 67 | return GetDBError(q) 68 | } 69 | 70 | // Update will update the given receiver model with respect to primary key / id available in it. 71 | // If selective list is non empty, only those fields which are present in the list will be updated. 72 | // Note: When using selectiveList `updated_at` field need not be passed in the list. 73 | func (repo Repo) Update(ctx context.Context, receiver IModel, selectiveList ...string) error { 74 | if len(selectiveList) > 0 { 75 | selectiveList = append(selectiveList, updatedAtField) 76 | } 77 | return repo.updateSelective(ctx, receiver, selectiveList...) 
78 | } 79 | 80 | // Delete deletes the given model 81 | // Soft or hard delete of model depends on the models implementation 82 | // if the model composites SoftDeletableModel then it'll be soft deleted 83 | func (repo Repo) Delete(ctx context.Context, receiver IModel) error { 84 | q := repo.DBInstance(ctx).Select(clause.Associations).Delete(receiver) 85 | 86 | return GetDBError(q) 87 | } 88 | 89 | func (repo Repo) ClearAssociations(ctx context.Context, receiver IModel, name string) error { 90 | err := repo.DBInstance(ctx).Model(receiver).Association(name).Clear() 91 | 92 | return err 93 | } 94 | 95 | func (repo Repo) ReplaceAssociations(ctx context.Context, receiver IModel, name string, ass interface{}) error { 96 | err := repo.DBInstance(ctx).Model(receiver).Association(name).Replace(ass) 97 | 98 | return err 99 | } 100 | 101 | // FineMany will fetch multiple records form the entity defined by receiver which matched the condition provided 102 | // note: this wont work for in clause. can be used only for `=` conditions 103 | func (repo Repo) FindMany( 104 | ctx context.Context, 105 | receivers interface{}, 106 | condition map[string]interface{}) error { 107 | 108 | q := repo.DBInstance(ctx).Where(condition).Find(receivers) 109 | 110 | return GetDBError(q) 111 | } 112 | 113 | // Preload preload associations with given conditions 114 | // repo.Preload(ctx, "Orders", "state NOT IN (?)", "cancelled").FindMany(ctx, &users) 115 | func (repo Repo) Preload(ctx context.Context, query string, args ...interface{}) *Repo { 116 | return &Repo{ 117 | Db: repo.Db.Preload(ctx, query, args), 118 | } 119 | } 120 | 121 | // Transaction will manage the execution inside a transactions 122 | // adds the txn db in the context for downstream use case 123 | func (repo Repo) Transaction(ctx context.Context, fc func(ctx context.Context) error) error { 124 | err := repo.DBInstance(ctx).Transaction(func(tx *gorm.DB) error { 125 | // This will ensure that when db.Instance(context) we return 
the txn on the context 126 | // & all repo queries are done on this txn. Refer usage in test. 127 | if err := fc(context.WithValue(ctx, db.ContextKeyDatabase, tx)); err != nil { 128 | return err 129 | } 130 | 131 | return GetDBError(tx) 132 | }) 133 | 134 | if err == nil { 135 | return nil 136 | } 137 | 138 | // tx.Commit can throw an error which will not be an IError 139 | if iErr, ok := err.(error); ok { 140 | return iErr 141 | } 142 | 143 | // use the default code and wrap err in internal 144 | return err 145 | } 146 | 147 | // IsTransactionActive returns true if a transaction is active 148 | func (repo Repo) IsTransactionActive(ctx context.Context) bool { 149 | _, ok := ctx.Value(db.ContextKeyDatabase).(*gorm.DB) 150 | return ok 151 | } 152 | 153 | // DBInstance returns gorm instance. 154 | // If replicas are specified, for Query, Row callback, will use replicas, unless Write mode specified. 155 | // For Raw callback, statements are considered read-only and will use replicas if the SQL starts with SELECT. 156 | // 157 | func (repo Repo) DBInstance(ctx context.Context) *gorm.DB { 158 | return repo.Db.Instance(ctx) 159 | } 160 | 161 | // WriteDBInstance returns a gorm instance of source/primary db connection. 162 | func (repo Repo) WriteDBInstance(ctx context.Context) *gorm.DB { 163 | return repo.DBInstance(ctx).Clauses(dbresolver.Write) 164 | } 165 | 166 | // WarmStorageDBInstance returns gorm instance of source/primary db connection. 167 | func (repo Repo) WarmStorageDBInstance(ctx context.Context) *gorm.DB { 168 | return repo.DBInstance(ctx).Clauses(dbresolver.Use(db.WarmStorageDBResolverName)) 169 | } 170 | 171 | // updateSelective will update the given receiver model with respect to primary key / id available in it. 172 | // If selective list is non empty, only those fields which are present in the list will be updated. 173 | // Note: When using selectiveList `updated_at` field also needs to be explicitly passed in the selectiveList. 
func (repo Repo) updateSelective(ctx context.Context, receiver IModel, selectiveList ...string) error {
	q := repo.DBInstance(ctx).Model(receiver)

	if len(selectiveList) > 0 {
		q = q.Select(selectiveList)
	}

	q = q.Updates(receiver)

	// Surface the query error first: a failed UPDATE also reports
	// RowsAffected == 0, and returning NoRowAffected would hide the real cause.
	if err := GetDBError(q); err != nil {
		return err
	}

	if q.RowsAffected == 0 {
		return NoRowAffected
	}

	return nil
}

--------------------------------------------------------------------------------
/scripts/compile.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Generates Go code (protoc-gen-go + twirp) and the OpenAPI v2 spec for
# rpc/gateway/service.proto, then re-vendors dependencies.
set -euo pipefail

go mod download

# Include path into the grpc-gateway module in the module cache. It is resolved
# through `go env GOMODCACHE` rather than "$GOPATH/pkg/mod" because GOPATH is
# often unset in module mode, which turned the path into "/pkg/mod/...".
GRPC_GATEWAY_VERSION="$(go list -m -mod=mod -u github.com/grpc-ecosystem/grpc-gateway/v2 | awk '{print $2}')"
GRPC_GATEWAY_DIR="$(go env GOMODCACHE)/github.com/grpc-ecosystem/grpc-gateway/v2@${GRPC_GATEWAY_VERSION}"

protoc -I . \
  -I "${GRPC_GATEWAY_DIR}" \
  --openapiv2_opt logtostderr=true \
  --openapiv2_opt generate_unbound_methods=true \
  --openapiv2_out ./third_party/swaggerui \
  --twirp_out=. \
  --go_out=. \
  rpc/gateway/service.proto

go mod vendor

--------------------------------------------------------------------------------
/scripts/coverage.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/scripts/coverage.sh

--------------------------------------------------------------------------------
/scripts/docker.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Starts the dev stack when called with exactly `up`; stops it otherwise.

compose_file=build/docker/dev/docker-compose.yml

if [[ "$*" == "up" ]]; then
  docker-compose -f "$compose_file" up -d --build
else
  docker-compose -f "$compose_file" down
fi

--------------------------------------------------------------------------------
/scripts/run-example.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Registers a backend, a group and a policy, all with id "dev", on a gateway
# whose admin API listens on localhost:8000. The policy routes requests on
# listening port 8080 to the "dev" group.

# Create dummy setup
curl --request POST \
  --url http://localhost:8000/twirp/razorpay.gateway.BackendApi/CreateOrUpdateBackend \
  --header 'Content-Type: application/json' \
  --header 'X-Auth-Key: test123' \
  --data '{
  "hostname": "docker.for.mac.localhost:37000",
  "scheme": "https",
  "id": "dev",
  "external_url": "docker.for.mac.localhost:37000",
  "is_enabled": true,
  "uptime_schedule": "* * * * *",
  "cluster_load": 0,
  "threshold_cluster_load": 0,
  "stats_updated_at": "0"
}'

curl --request POST \
  --url http://localhost:8000/twirp/razorpay.gateway.GroupApi/CreateOrUpdateGroup \
  --header 'Content-Type: application/json' \
  --header 'X-Auth-Key: test123' \
  --data '{
  "id": "dev",
  "backends": ["dev"],
  "strategy": "RANDOM",
  "last_routed_backend": "dev",
  "is_enabled": true
}'

curl --request POST \
  --url http://localhost:8000/twirp/razorpay.gateway.PolicyApi/CreateOrUpdatePolicy \
  --header 'Content-Type: application/json' \
  --header 'X-Auth-Key: test123' \
  --data '{
  "id": "dev",
  "rule": {
    "type": "listening_port",
    "value": "8080"
  },
  "group": "dev",
  "fallback_group": "dev",
  "is_enabled": true,
  "is_auth_delegated": false,
  "set_request_source": "localDev"
}'

--------------------------------------------------------------------------------
/scripts/setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Installs the code-generation tools that tools.go pins in go.mod.

go mod download

go install github.com/golang/protobuf/protoc-gen-go
go install github.com/gopherjs/gopherjs
go install github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2
go install github.com/twitchtv/twirp/protoc-gen-twirp

--------------------------------------------------------------------------------
/third_party/swaggerui/favicon-16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/third_party/swaggerui/favicon-16x16.png

--------------------------------------------------------------------------------
/third_party/swaggerui/favicon-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/third_party/swaggerui/favicon-32x32.png

--------------------------------------------------------------------------------
/third_party/swaggerui/index.css:
--------------------------------------------------------------------------------
html {
  box-sizing: border-box;
  overflow: -moz-scrollbars-vertical;
  overflow-y: scroll;
}

*,
*:before,
*:after {
  box-sizing: inherit;
}

body {
  margin: 0;
  background: #fafafa;
}

--------------------------------------------------------------------------------
/third_party/swaggerui/index.html:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | Swagger UI 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 |

--------------------------------------------------------------------------------
/third_party/swaggerui/oauth2-redirect.html:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | Swagger UI: OAuth2 Redirect 5 | 6 | 7 | 78 | 79 | 80 |

--------------------------------------------------------------------------------
/third_party/swaggerui/swagger-initializer.js:
--------------------------------------------------------------------------------
window.onload = function() {
  //

  // the following lines will be replaced by docker/configurator, when it runs in a docker-container
  window.ui = SwaggerUIBundle({
    url: "./rpc/gateway/service.swagger.json",
    dom_id: '#swagger-ui',
    deepLinking: true,
    presets: [
      SwaggerUIBundle.presets.apis,
      SwaggerUIStandalonePreset
    ],
    plugins: [
      SwaggerUIBundle.plugins.DownloadUrl,
      HideInfoUrlPartsPlugin,
      HideOperationsUntilAuthorizedPlugin
    ],
    layout: "StandaloneLayout"
  });

  //
};

// Hides the InfoUrl component in the info section.
const HideInfoUrlPartsPlugin = () => {
  return {
    wrapComponents: {
      InfoUrl: () => () => null,
      // InfoBasePath: () => () => null, // this hides the `Base Url` part too, if you want that
    }
  }
}

// Renders an operation only when it is unsecured or already authorized.
const HideOperationsUntilAuthorizedPlugin = function() {
  return {
    wrapComponents: {
      operation: (Ori, system) => (props) => {
        // An operation without a `security` entry would otherwise make `.size`
        // throw; treat it as unsecured instead.
        const security = props.operation.get("security")
        const isOperationSecured = !!(security && security.size)
        const isOperationAuthorized = props.operation.get("isAuthorized")

        if (!isOperationSecured || isOperationAuthorized) {
          return system.React.createElement(Ori, props)
        }
        return null
      }
    }
  }
}

--------------------------------------------------------------------------------
/tools.go:
--------------------------------------------------------------------------------
//go:build tools
// +build tools

// Package tools pins the code-generation tools in go.mod; the "tools" build
// tag keeps it out of normal builds.
package tools

import (
	// _ "github.com/elliots/protoc-gen-twirp_swagger"
	_ "github.com/golang/protobuf/protoc-gen-go"
	_ "github.com/gopherjs/gopherjs"
	_ "github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2"
	_ "github.com/twitchtv/twirp/protoc-gen-twirp"
)

--------------------------------------------------------------------------------
/web/frontend/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/razorpay/trino-gateway/3f8ceab54b91b8c2b64a9339fb4c82c4432d6655/web/frontend/favicon.ico

--------------------------------------------------------------------------------
/web/frontend/index.html:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
--------------------------------------------------------------------------------