├── .github └── workflows │ ├── checks.yml │ ├── release-tee.yml │ └── release.yml ├── .gitignore ├── .golangci.yml ├── Dockerfile ├── Dockerfile.tee ├── LICENSE ├── Makefile ├── README.md ├── docs └── loadtest │ ├── README.md │ ├── script.js │ └── transactions.json ├── go.mod ├── go.sum ├── main.go ├── server ├── consts.go ├── errors.go ├── http_logger.go ├── node.go ├── node_notee.go ├── node_tee.go ├── node_test.go ├── nodepool.go ├── nodepool_test.go ├── queue.go ├── queue_test.go ├── redis.go ├── redis_test.go ├── server.go ├── server_test.go ├── types.go ├── utils.go ├── webserver.go └── webserver_test.go ├── staticcheck.conf └── testutils ├── mockserver.go ├── mockserver_test.go └── types.go /.github/workflows/checks.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | pull_request: 6 | 7 | name: Checks 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v4 14 | 15 | - name: Install Go 16 | uses: actions/setup-go@v5 17 | with: 18 | go-version: ^1.22 19 | id: go 20 | 21 | - name: Install gofumpt 22 | run: go install mvdan.cc/gofumpt@v0.4.0 23 | 24 | - name: Install staticcheck 25 | run: go install honnef.co/go/tools/cmd/staticcheck@v0.6.1 26 | 27 | # - name: Install golangci-lint 28 | # run: go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.51.2 29 | 30 | - name: Lint 31 | run: make lint 32 | 33 | - name: Build 34 | run: make build build-tee 35 | 36 | - name: Ensure go mod tidy runs without changes 37 | run: | 38 | go mod tidy 39 | git diff-index HEAD 40 | git diff-index --quiet HEAD 41 | 42 | test: 43 | runs-on: ubuntu-latest 44 | steps: 45 | - name: Checkout code 46 | uses: actions/checkout@v4 47 | 48 | - name: Install Go 49 | uses: actions/setup-go@v5 50 | with: 51 | go-version: ^1.22 52 | id: go 53 | 54 | - name: Test 55 | run: make test 56 | 
-------------------------------------------------------------------------------- /.github/workflows/release-tee.yml: -------------------------------------------------------------------------------- 1 | name: ReleaseTEE 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | docker-image: 10 | name: Publish Docker TEE Image 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout sources 15 | uses: actions/checkout@v4 16 | 17 | - name: Get tag version 18 | run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV 19 | 20 | - name: Print version 21 | run: | 22 | echo $RELEASE_VERSION 23 | echo ${{ env.RELEASE_VERSION }} 24 | 25 | - name: Set up QEMU 26 | uses: docker/setup-qemu-action@v2 27 | 28 | - name: Set up Docker Buildx 29 | uses: docker/setup-buildx-action@v2 30 | 31 | - name: Extract metadata (tags, labels) for Docker 32 | id: meta 33 | uses: docker/metadata-action@v4 34 | with: 35 | images: flashbots/prio-load-balancer 36 | flavor: | 37 | suffix=-tee,onlatest=true 38 | tags: | 39 | type=sha 40 | type=pep440,pattern={{version}} 41 | type=pep440,pattern={{major}}.{{minor}} 42 | type=raw,value=latest,enable=${{ !contains(env.RELEASE_VERSION, '-') }} 43 | 44 | - name: Login to DockerHub 45 | uses: docker/login-action@v2 46 | with: 47 | username: ${{ secrets.DOCKERHUB_USERNAME }} 48 | password: ${{ secrets.DOCKERHUB_TOKEN }} 49 | 50 | - name: Build and push 51 | uses: docker/build-push-action@v3 52 | with: 53 | context: . 
54 | file: Dockerfile.tee 55 | push: true 56 | build-args: | 57 | VERSION=${{ env.RELEASE_VERSION }} 58 | platforms: linux/amd64 59 | tags: ${{ steps.meta.outputs.tags }} 60 | labels: ${{ steps.meta.outputs.labels }} 61 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | docker-image: 10 | name: Publish Docker Image 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout sources 15 | uses: actions/checkout@v4 16 | 17 | - name: Get tag version 18 | run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV 19 | 20 | - name: Print version 21 | run: | 22 | echo $RELEASE_VERSION 23 | echo ${{ env.RELEASE_VERSION }} 24 | 25 | - name: Set up QEMU 26 | uses: docker/setup-qemu-action@v2 27 | 28 | - name: Set up Docker Buildx 29 | uses: docker/setup-buildx-action@v2 30 | 31 | - name: Extract metadata (tags, labels) for Docker 32 | id: meta 33 | uses: docker/metadata-action@v4 34 | with: 35 | images: flashbots/prio-load-balancer 36 | tags: | 37 | type=sha 38 | type=pep440,pattern={{version}} 39 | type=pep440,pattern={{major}}.{{minor}} 40 | type=raw,value=latest,enable=${{ !contains(env.RELEASE_VERSION, '-') }} 41 | 42 | - name: Login to DockerHub 43 | uses: docker/login-action@v2 44 | with: 45 | username: ${{ secrets.DOCKERHUB_USERNAME }} 46 | password: ${{ secrets.DOCKERHUB_TOKEN }} 47 | 48 | - name: Build and push 49 | uses: docker/build-push-action@v3 50 | with: 51 | context: . 
52 | push: true 53 | build-args: | 54 | VERSION=${{ env.RELEASE_VERSION }} 55 | platforms: linux/amd64,linux/arm64,linux/arm/v6 56 | tags: ${{ steps.meta.outputs.tags }} 57 | labels: ${{ steps.meta.outputs.labels }} 58 | 59 | github-release: 60 | runs-on: ubuntu-latest 61 | steps: 62 | - name: Checkout sources 63 | uses: actions/checkout@v2 64 | 65 | - name: Create release 66 | id: create_release 67 | uses: actions/create-release@v1 68 | env: 69 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 70 | with: 71 | tag_name: ${{ github.ref }} 72 | release_name: ${{ github.ref }} 73 | draft: true 74 | prerelease: false 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | /rpc-endpoint 3 | /tmp 4 | .env* 5 | /prio-load-balancer 6 | /prio-load-balancer 7 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | enable-all: true 3 | disable: 4 | - exhaustruct 5 | - funlen 6 | - gochecknoglobals 7 | - gochecknoinits 8 | - gocritic 9 | - godot 10 | - godox 11 | - gomnd 12 | - lll 13 | - nlreturn 14 | - nonamedreturns 15 | - nosnakecase 16 | - paralleltest 17 | - testpackage 18 | - varnamelen 19 | - wrapcheck 20 | - wsl 21 | 22 | # 23 | # Maybe fix later: 24 | # 25 | - cyclop 26 | - gocognit 27 | - goconst 28 | - gosec 29 | - ireturn 30 | - noctx 31 | - tagliatelle 32 | 33 | # 34 | # Disabled because of generics: 35 | # 36 | - contextcheck 37 | - rowserrcheck 38 | - sqlclosecheck 39 | - structcheck 40 | - wastedassign 41 | 42 | # 43 | # Disabled because deprecated: 44 | # 45 | - deadcode 46 | - exhaustivestruct 47 | - golint 48 | - ifshort 49 | - interfacer 50 | - maligned 51 | - scopelint 52 | - varcheck 53 | 54 | linters-settings: 55 | gofumpt: 56 | extra-rules: true 57 | govet: 58 | enable-all: true 59 | 
disable: 60 | - fieldalignment 61 | - shadow 62 | 63 | output: 64 | print-issued-lines: true 65 | sort-results: true 66 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | FROM golang:1.22 as builder 3 | ARG VERSION 4 | WORKDIR /build 5 | ADD . /build/ 6 | RUN --mount=type=cache,target=/root/.cache/go-build CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags "-s -X main.version=$VERSION" -v -o prio-load-balancer main.go 7 | 8 | FROM scratch 9 | WORKDIR /app 10 | COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ 11 | COPY --from=builder /build/prio-load-balancer /app/prio-load-balancer 12 | ENV LISTEN_ADDR=":8080" 13 | EXPOSE 8080 14 | CMD ["/app/prio-load-balancer"] 15 | -------------------------------------------------------------------------------- /Dockerfile.tee: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | FROM golang:1.22 as builder 3 | ARG VERSION 4 | WORKDIR /build 5 | ADD . 
/build/ 6 | RUN --mount=type=cache,target=/root/.cache/go-build GOOS=linux go build --tags tee -trimpath -ldflags "-s -X main.version=$VERSION" -v -o prio-load-balancer main.go 7 | 8 | FROM ubuntu:20.04 as repos 9 | RUN apt-get update && \ 10 | DEBIAN_FRONTEND=noninteractive apt-get install -y curl && \ 11 | curl -fsSLo /usr/share/keyrings/gramine-keyring.gpg https://packages.gramineproject.io/gramine-keyring.gpg && \ 12 | echo 'deb [arch=amd64 signed-by=/usr/share/keyrings/gramine-keyring.gpg] https://packages.gramineproject.io/ focal main' > /etc/apt/sources.list.d/gramine.list && \ 13 | curl -fsSLo /usr/share/keyrings/intel-sgx-deb.key https://download.01.org/intel-sgx/sgx_repo/ubuntu/intel-sgx-deb.key && \ 14 | echo 'deb [arch=amd64 signed-by=/usr/share/keyrings/intel-sgx-deb.key] https://download.01.org/intel-sgx/sgx_repo/ubuntu focal main' > /etc/apt/sources.list.d/intel-sgx.list && \ 15 | curl -fsSLo /usr/share/keyrings/microsoft.key https://packages.microsoft.com/keys/microsoft.asc && \ 16 | echo 'deb [arch=amd64 signed-by=/usr/share/keyrings/microsoft.key] https://packages.microsoft.com/ubuntu/20.04/prod focal main' > /etc/apt/sources.list.d/microsoft.list 17 | 18 | FROM ubuntu:20.04 19 | 20 | RUN apt-get update && \ 21 | apt-get install -y ca-certificates && \ 22 | rm -rf /var/lib/apt/lists/* 23 | 24 | COPY --from=repos /usr/share/keyrings/gramine-keyring.gpg /usr/share/keyrings/gramine-keyring.gpg 25 | COPY --from=repos /usr/share/keyrings/intel-sgx-deb.key /usr/share/keyrings/intel-sgx-deb.key 26 | COPY --from=repos /usr/share/keyrings/microsoft.key /usr/share/keyrings/microsoft.key 27 | COPY --from=repos /etc/apt/sources.list.d/gramine.list /etc/apt/sources.list.d/gramine.list 28 | COPY --from=repos /etc/apt/sources.list.d/intel-sgx.list /etc/apt/sources.list.d/intel-sgx.list 29 | COPY --from=repos /etc/apt/sources.list.d/microsoft.list /etc/apt/sources.list.d/microsoft.list 30 | 31 | RUN apt-get update && \ 32 | DEBIAN_FRONTEND=noninteractive apt-get 
install -y --no-install-recommends \ 33 | az-dcap-client \ 34 | libsgx-urts \ 35 | libsgx-dcap-quote-verify && \ 36 | DEBIAN_FRONTEND=noninteractive apt-get download -y gramine gramine-ratls-dcap && \ 37 | apt-get clean autoclean && apt-get autoremove --yes && \ 38 | dpkg -i --force-depends *.deb && \ 39 | rm *.deb && \ 40 | rm -rf /var/lib/apt/lists/* 41 | 42 | WORKDIR /app 43 | COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ 44 | COPY --from=builder /build/prio-load-balancer /app/prio-load-balancer 45 | ENV LISTEN_ADDR=":8080" 46 | EXPOSE 8080 47 | CMD ["/app/prio-load-balancer"] 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021-2022 Flashbots 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all v build test clean lint cover cover-html docker-image 2 | 3 | VERSION := $(shell git describe --tags --always --dirty="-dev") 4 | 5 | all: clean build 6 | 7 | v: 8 | @echo "Version: ${VERSION}" 9 | 10 | run-dev: 11 | go run . -mock-node -log-prod 12 | 13 | build: 14 | go build -trimpath -ldflags "-s -X main.version=${VERSION}" -v -o prio-load-balancer main.go 15 | 16 | build-tee: 17 | go build -tags tee -trimpath -ldflags "-s -X main.version=${VERSION}" -v -o prio-load-balancer main.go 18 | 19 | clean: 20 | rm -rf prio-load-balancer build/ 21 | 22 | test: 23 | go test ./... 24 | 25 | lint: 26 | gofmt -d -s . 27 | gofumpt -d -extra . 28 | go vet ./... 29 | go vet --tags=tee ./... 30 | staticcheck ./... 31 | # golangci-lint run 32 | 33 | lt: lint test 34 | 35 | lint-strict: lint 36 | gofumpt -d -extra . 37 | golangci-lint run 38 | 39 | fmt: 40 | gofmt -s -w . 41 | gofumpt -extra -w . 42 | gci write . 43 | go mod tidy 44 | 45 | cover: 46 | go test -coverprofile=/tmp/go-prio-lb.cover.tmp ./... 47 | go tool cover -func /tmp/go-prio-lb.cover.tmp 48 | unlink /tmp/go-prio-lb.cover.tmp 49 | 50 | cover-html: 51 | go test -coverprofile=/tmp/go-prio-lb.cover.tmp ./... 52 | go tool cover -html=/tmp/go-prio-lb.cover.tmp 53 | unlink /tmp/go-prio-lb.cover.tmp 54 | 55 | docker-image: 56 | DOCKER_BUILDKIT=1 docker build --platform linux/amd64 --build-arg VERSION=${VERSION} . -t prio-load-balancer 57 | 58 | docker-image-tee: 59 | DOCKER_BUILDKIT=1 docker build --platform linux/amd64 --build-arg VERSION=${VERSION} . 
-f Dockerfile.tee -t prio-load-balancer 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transparent JSON-RPC proxy and load balancer 2 | 3 | [![Goreport status](https://goreportcard.com/badge/github.com/flashbots/prio-load-balancer)](https://goreportcard.com/report/github.com/flashbots/prio-load-balancer) 4 | [![Test status](https://github.com/flashbots/prio-load-balancer/workflows/Checks/badge.svg)](https://github.com/flashbots/prio-load-balancer/actions?query=workflow%3A%22Checks%22) 5 | [![Docker hub](https://badgen.net/docker/size/flashbots/prio-load-balancer?icon=docker&label=image)](https://hub.docker.com/r/flashbots/prio-load-balancer/tags) 6 | 7 | **With priority queues, retries, good logging and metrics, and even SGX/SEV attestation support.** 8 | 9 | Queues: 10 | 11 | 1. low-prio 12 | 2. high-prio 13 | 3. fast-track 14 | 15 | Queueing: 16 | 17 | - All high-prio requests will be proxied before any of the low-prio queue 18 | - [N](https://github.com/flashbots/prio-load-balancer/blob/main/server/consts.go#L20) fast-tracked requests get processed for every 1 high-prio request 19 | 20 | Further notes: 21 | 22 | - A _node_ represents one JSON-RPC endpoint (i.e. 
geth instance) 23 | - Each node spins up N workers, which proxy requests concurrently to the execution endpoint 24 | - You can add/remove nodes through a JSON API without restarting the server 25 | - Each node starts the default number of workers, but you can also specify a custom number of workers by adding `?_workers=` to the node URL 26 | - It's possible to tweak [a few knobs](/server/consts.go) 27 | 28 | 29 | --- 30 | 31 | ### Application structure and request flow 32 | 33 | ![App structure and request flow](https://user-images.githubusercontent.com/116939/202170917-bcd98c98-f40e-4025-8084-06adec27ff96.png) 34 | 35 | ### Example logs 36 | 37 | At the end of a request: 38 | 39 | ```json 40 | { 41 | "level": "info", 42 | "ts": 1685122704.4079978, 43 | "caller": "server/webserver.go:154", 44 | "msg": "Request completed", 45 | "service": "validation-queue", 46 | "durationMs": 144, 47 | "requestIsHighPrio": true, 48 | "requestIsFastTrack": false, 49 | "payloadSize": 209394, 50 | "statusCode": 200, 51 | "nodeURI": "http://validation-1.internal:8545", 52 | "requestTries": 1, 53 | "queueItems": 11, 54 | "queueItemsFastTrack": 0, 55 | "queueItemsHighPrio": 7, 56 | "queueItemsLowPrio": 4 57 | } 58 | ``` 59 | 60 | Full request cycle: 61 | 62 | ```json 63 | // request getting added to the queue 64 | {"level":"info","ts":1685126514.8569882,"caller":"server/webserver.go:112","msg":"Request added to queue. 
prioQueue size:","service":"validation-queue","requestIsHighPrio":true,"requestIsFastTrack":true,"fastTrack":1,"highPrio":0,"lowPrio":0} 65 | 66 | // completed request 67 | {"level":"info","ts":1685126514.9174724,"caller":"server/webserver.go:154","msg":"Request completed","service":"validation-queue","durationMs":78,"requestIsHighPrio":true,"requestIsFastTrack":true,"payloadSize":121291,"statusCode":200,"nodeURI":"http://validation-2.internal:8545","requestTries":1,"queueItems":0,"queueItemsFastTrack":0,"queueItemsHighPrio":0,"queueItemsLowPrio":0} 68 | 69 | // http server logs 70 | {"level":"info","ts":1685126514.9175296,"caller":"server/http_logger.go:54","msg":"http: POST /sim 200","service":"validation-queue","status":200,"method":"POST","path":"/sim","duration":0.078877451} 71 | ``` 72 | 73 | --- 74 | 75 | ## Getting started 76 | 77 | Docker images are available at https://hub.docker.com/r/flashbots/prio-load-balancer 78 | 79 | #### Run the program 80 | 81 | ```bash 82 | # Run with a mock execution backend and debug output 83 | go run . -mock-node # text logging 84 | go run . 
-mock-node -log-prod # json logging 85 | 86 | # low-prio queue request 87 | curl -d '{"jsonrpc":"2.0","method":"eth_callBundle","params":[],"id":1}' localhost:8080 88 | 89 | # high-prio queue request 90 | curl -H 'X-High-Priority: true' -d '{"jsonrpc":"2.0","method":"eth_callBundle","params":[],"id":1}' localhost:8080 91 | 92 | # fast-track queue request 93 | curl -H 'X-Fast-Track: true' -d '{"jsonrpc":"2.0","method":"eth_callBundle","params":[],"id":1}' localhost:8080 94 | 95 | # adding a custom request ID 96 | curl -H 'X-Request-ID: yourLogID' -d '{"jsonrpc":"2.0","method":"eth_callBundle","params":[],"id":1}' localhost:8080 97 | 98 | # Get execution nodes 99 | curl localhost:8080/nodes 100 | 101 | # Add an execution node 102 | curl -d '{"uri":"http://foo"}' localhost:8080/nodes 103 | 104 | # Add an execution node with a custom number of workers 105 | curl -d '{"uri":"http://foo?_workers=8"}' localhost:8080/nodes 106 | 107 | # Remove an execution node 108 | curl -X DELETE -d '{"uri":"http://foo"}' localhost:8080/nodes 109 | curl -X DELETE -d '{"uri":"http://localhost:8095"}' localhost:8080/nodes 110 | ``` 111 | 112 | Note: there's a bunch of constants that can be configured with env vars in [server/consts.go](server/consts.go). 113 | 114 | #### Node selection 115 | 116 | * Redis is used as source of truth for which execution nodes to use. 117 | * If you restart with a different set of configured nodes (i.e. in env vars), the previous nodes will still be in Redis and still be used by the load balancer. 118 | * See the commands in the readme above on how to get the nodes it uses, and how to add/remove nodes. 
119 | 120 | #### Test, lint, build 121 | 122 | ```bash 123 | # lint & staticcheck (staticcheck.io) 124 | make lint 125 | 126 | # run tests 127 | make test 128 | 129 | # test coverage 130 | make cover 131 | make cover-html 132 | 133 | # build 134 | make build 135 | ``` 136 | 137 | #### Node TEE attestation via TLS 138 | ``` 139 | # build prio-load-balancer with SGX and SEV support 140 | make build-tee 141 | ``` 142 | 143 | > **IMPORTANT:** SGX and SEV attestation support requires additional dependencies. See [Dockerfile.tee](Dockerfile.tee) for details. 144 | 145 | #### SEV Node aTLS attestation 146 | 147 | ``` 148 | # base64 encode the VM measurements 149 | 150 | MEASUREMENTS=$(cat << EOF | gzip | basenc --base64url -w0 151 | { 152 | "1": { 153 | "expected": "3d458cfe55cc03ea1f443f1562beec8df51c75e14a9fcf9a7234a13f198e7969", 154 | "warnOnly": true 155 | }, 156 | "2": { 157 | "expected": "3d458cfe55cc03ea1f443f1562beec8df51c75e14a9fcf9a7234a13f198e7969", 158 | "warnOnly": true 159 | }, 160 | "3": { 161 | "expected": "3d458cfe55cc03ea1f443f1562beec8df51c75e14a9fcf9a7234a13f198e7969", 162 | "warnOnly": true 163 | }, 164 | "4": { 165 | "expected": "82736cdd6b4f3c718bf969b545eaaa6eb3f1e6d229ad9712e6a4ddf431418ab7", 166 | "warnOnly": false 167 | }, 168 | "5": { 169 | "expected": "54c04bcd7cf8adadafee915bf325f92d958050c14e086c1e180258113d376c1a", 170 | "warnOnly": true 171 | }, 172 | "6": { 173 | "expected": "9319868ef4dad6a79117f14b9ac1870ccf5f9d178b39a3fd84e6230fa93a7993", 174 | "warnOnly": true 175 | }, 176 | "7": { 177 | "expected": "32fe42b385b47cb22c906b8a7e4f134e9f2270818f90e94072d1101ef72f1c00", 178 | "warnOnly": true 179 | }, 180 | "8": { 181 | "expected": "0000000000000000000000000000000000000000000000000000000000000000", 182 | "warnOnly": false 183 | }, 184 | "9": { 185 | "expected": "0000000000000000000000000000000000000000000000000000000000000000", 186 | "warnOnly": false 187 | }, 188 | "11": { 189 | "expected": 
"0000000000000000000000000000000000000000000000000000000000000000", 190 | "warnOnly": false 191 | }, 192 | "12": { 193 | "expected": "f1a142c53586e7e2223ec74e5f4d1a4942956b1fd9ac78fafcdf85117aa345da", 194 | "warnOnly": false 195 | }, 196 | "13": { 197 | "expected": "0000000000000000000000000000000000000000000000000000000000000000", 198 | "warnOnly": false 199 | }, 200 | "14": { 201 | "expected": "e3991b7ddd47be7e92726a832d6874c5349b52b789fa0db8b558c69fea29574e", 202 | "warnOnly": true 203 | }, 204 | "15": { 205 | "expected": "0000000000000000000000000000000000000000000000000000000000000000", 206 | "warnOnly": false 207 | } 208 | } 209 | EOF 210 | ) 211 | ``` 212 | 213 | ``` 214 | # Add the SEV execution node 215 | curl -d "{\"uri\":\"https://SEV_${MEASUREMENTS}@foo\"}" localhost:8080/nodes 216 | ``` 217 | 218 | Execution nodes running within SEV and providing attestation consumables via constellations aTLS implementation are supported. The aTLS certificate of the execution node is automatically attested with the VM measurements which are submitted as part of the user part of the **node URI** (`SEV_`). You can read more about the attestation measurements in the [constellation docs](https://docs.edgeless.systems/constellation/architecture/attestation#runtime-measurements) 219 | 220 | #### SGX Node RA-TLS attestation 221 | ``` 222 | # Add an SGX execution node 223 | curl -d '{"uri":"https://SGX_@foo"}' localhost:8080/nodes 224 | ``` 225 | 226 | Execution nodes running within SGX and providing attestation consumables via RA-TLS are supported. The RA-TLS certificate of the execution node is automatically attested with the `MRENCLAVE` which is submitted as part of the user part of **node URI** (`SGX_`). 
227 | 228 | --- 229 | 230 | ## Queue Benchmarks 231 | 232 | ``` 233 | goarch: amd64 234 | pkg: github.com/flashbots/prio-load-balancer/server 235 | cpu: Intel(R) Core(TM) i9-8950HK CPU @ 2.90GHz 236 | 237 | 1 worker, 10k tasks: 238 | BenchmarkPrioQueue-12 2338 492219 ns/op 298109 B/op 34 allocs/op 239 | 240 | 5 workers, 10k tasks: 241 | BenchmarkPrioQueueMultiReader-12 2690 596315 ns/op 292507 B/op 50 allocs/op 242 | 243 | 5 workers, 100k tasks: 244 | BenchmarkPrioQueueMultiReader-12 261 4637403 ns/op 4245243 B/op 66 allocs/op 245 | ``` 246 | 247 | --- 248 | 249 | ## Todo 250 | 251 | Possibly 252 | 253 | * Configurable redis prefix, to allow multiple sim-lbs per redis instance 254 | * Execution-node health checks (currently not implemented) 255 | 256 | --- 257 | 258 | ## Maintainers 259 | 260 | - [@metachris](https://twitter.com/metachris) 261 | 262 | --- 263 | 264 | ## License 265 | 266 | The code in this project is free software under the [MIT License](LICENSE). 267 | -------------------------------------------------------------------------------- /docs/loadtest/README.md: -------------------------------------------------------------------------------- 1 | https://k6.io/docs/getting-started/running-k6/ 2 | 3 | Update `blockNumber` in `script.js` 4 | 5 | ```bash 6 | # Run the script 1x 7 | k6 run script.js 8 | 9 | # Run 10 parallel request loops, for 30 sec total 10 | k6 run --vus 10 --duration 30s script.js 11 | ``` 12 | -------------------------------------------------------------------------------- /docs/loadtest/script.js: -------------------------------------------------------------------------------- 1 | import http from 'k6/http'; 2 | import { check } from "k6"; 3 | import { SharedArray } from 'k6/data'; 4 | 5 | const blockNumber = 14050699; 6 | 7 | const data = new SharedArray('txs', function () { 8 | return JSON.parse(open('./transactions.json')); 9 | }); 10 | 11 | const url = "YOUR_URL" 12 | 13 | export default function () { 14 | const idx = 
Math.floor(Math.random() * data.length) 15 | const isHighPrio = Math.random() < 0.5 16 | 17 | console.log("req tx:", idx, "highPrio:", isHighPrio) 18 | 19 | var params = { 20 | headers: { 21 | 'Content-Type': 'application/json', 22 | 'high_priority': isHighPrio 23 | }, 24 | }; 25 | 26 | const payload = { 27 | "jsonrpc": "2.0", 28 | "id": 1, 29 | "method": "eth_callBundle", 30 | "params": [ 31 | { 32 | "txs": [data[idx]], 33 | "blockNumber": `0x${(blockNumber + 10).toString(16)}`, 34 | "stateBlockNumber": "latest" 35 | } 36 | ] 37 | } 38 | 39 | const res = http.post(url, JSON.stringify(payload), params); 40 | const resData = res.json(); 41 | 42 | check(res, { "status is 200": (r) => r.status === 200 }); 43 | check(resData, { "response has no error": (d) => !d.error }); 44 | } 45 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/flashbots/prio-load-balancer 2 | 3 | go 1.20 4 | 5 | replace ( 6 | github.com/google/go-tpm => github.com/thomasten/go-tpm v0.0.0-20230222180349-bb3cc5560299 7 | github.com/google/go-tpm-tools => github.com/daniel-weisse/go-tpm-tools v0.0.0-20230105122812-f7474d459dfc 8 | ) 9 | 10 | require ( 11 | github.com/alicebob/miniredis v2.5.0+incompatible 12 | github.com/go-redis/redis/v8 v8.11.5 13 | github.com/gorilla/mux v1.8.0 14 | github.com/konvera/geth-sev v0.0.0-20230425080657-b02eb0266f3b 15 | github.com/konvera/gramine-ratls-golang v0.0.0-20230417022221-836955fa9223 16 | github.com/pkg/errors v0.9.1 17 | github.com/stretchr/testify v1.8.2 18 | go.uber.org/atomic v1.10.0 19 | go.uber.org/zap v1.24.0 20 | ) 21 | 22 | require ( 23 | code.cloudfoundry.org/clock v0.0.0-20180518195852-02e53af36e6c // indirect 24 | github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 // indirect 25 | github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.2 // indirect 26 | github.com/Azure/azure-sdk-for-go/sdk/internal 
v1.2.0 // indirect 27 | github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights v1.0.0 // indirect 28 | github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4 v4.1.0 // indirect 29 | github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v2 v2.1.0 // indirect 30 | github.com/AzureAD/microsoft-authentication-library-for-go v0.9.0 // indirect 31 | github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect 32 | github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d // indirect 33 | github.com/blang/semver v3.5.1+incompatible // indirect 34 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 35 | github.com/cyberphone/json-canonicalization v0.0.0-20210303052042-6bc126869bf4 // indirect 36 | github.com/davecgh/go-spew v1.1.1 // indirect 37 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 38 | github.com/edgelesssys/constellation/v2 v2.7.0 // indirect 39 | github.com/edgelesssys/go-azguestattestation v0.0.0-20230303085714-62ede861d33f // indirect 40 | github.com/go-chi/chi v4.1.2+incompatible // indirect 41 | github.com/go-jose/go-jose/v3 v3.0.0 // indirect 42 | github.com/go-openapi/analysis v0.21.4 // indirect 43 | github.com/go-openapi/errors v0.20.3 // indirect 44 | github.com/go-openapi/jsonpointer v0.19.6 // indirect 45 | github.com/go-openapi/jsonreference v0.20.1 // indirect 46 | github.com/go-openapi/loads v0.21.2 // indirect 47 | github.com/go-openapi/runtime v0.24.1 // indirect 48 | github.com/go-openapi/spec v0.20.7 // indirect 49 | github.com/go-openapi/strfmt v0.21.3 // indirect 50 | github.com/go-openapi/swag v0.22.3 // indirect 51 | github.com/go-openapi/validate v0.22.0 // indirect 52 | github.com/go-playground/locales v0.14.1 // indirect 53 | github.com/go-playground/universal-translator v0.18.1 // indirect 54 | github.com/go-playground/validator/v10 v10.11.2 // indirect 55 | github.com/gofrs/uuid 
v4.2.0+incompatible // indirect 56 | github.com/golang-jwt/jwt/v4 v4.5.0 // indirect 57 | github.com/golang/protobuf v1.5.2 // indirect 58 | github.com/gomodule/redigo v1.8.8 // indirect 59 | github.com/google/certificate-transparency-go v1.1.4 // indirect 60 | github.com/google/go-attestation v0.4.4-0.20221011162210-17f9c05652a9 // indirect 61 | github.com/google/go-containerregistry v0.13.0 // indirect 62 | github.com/google/go-sev-guest v0.4.1 // indirect 63 | github.com/google/go-tpm v0.3.3 // indirect 64 | github.com/google/go-tpm-tools v0.3.10 // indirect 65 | github.com/google/go-tspi v0.3.0 // indirect 66 | github.com/google/logger v1.1.1 // indirect 67 | github.com/google/trillian v1.5.1 // indirect 68 | github.com/google/uuid v1.3.0 // indirect 69 | github.com/hashicorp/go-cleanhttp v0.5.2 // indirect 70 | github.com/hashicorp/go-retryablehttp v0.7.2 // indirect 71 | github.com/inconshreveable/mousetrap v1.0.1 // indirect 72 | github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b // indirect 73 | github.com/josharian/intern v1.0.0 // indirect 74 | github.com/kylelemons/godebug v1.1.0 // indirect 75 | github.com/leodido/go-urn v1.2.2 // indirect 76 | github.com/letsencrypt/boulder v0.0.0-20221109233200-85aa52084eaf // indirect 77 | github.com/mailru/easyjson v0.7.7 // indirect 78 | github.com/microsoft/ApplicationInsights-Go v0.4.4 // indirect 79 | github.com/mitchellh/mapstructure v1.5.0 // indirect 80 | github.com/oklog/ulid v1.3.1 // indirect 81 | github.com/opencontainers/go-digest v1.0.0 // indirect 82 | github.com/opentracing/opentracing-go v1.2.0 // indirect 83 | github.com/pborman/uuid v1.2.1 // indirect 84 | github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect 85 | github.com/pmezard/go-difflib v1.0.0 // indirect 86 | github.com/sassoftware/relic v0.0.0-20210427151427-dfb082b79b74 // indirect 87 | github.com/secure-systems-lab/go-securesystemslib v0.5.0 // indirect 88 | github.com/siderolabs/talos/pkg/machinery v1.3.2 
// indirect 89 | github.com/sigstore/rekor v1.0.1 // indirect 90 | github.com/sigstore/sigstore v1.6.0 // indirect 91 | github.com/spf13/afero v1.9.5 // indirect 92 | github.com/spf13/cobra v1.6.1 // indirect 93 | github.com/spf13/pflag v1.0.5 // indirect 94 | github.com/tent/canonical-json-go v0.0.0-20130607151641-96e4ba3a7613 // indirect 95 | github.com/theupdateframework/go-tuf v0.5.2 // indirect 96 | github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 // indirect 97 | github.com/transparency-dev/merkle v0.0.1 // indirect 98 | github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect 99 | go.mongodb.org/mongo-driver v1.10.0 // indirect 100 | go.uber.org/multierr v1.9.0 // indirect 101 | golang.org/x/crypto v0.6.0 // indirect 102 | golang.org/x/exp v0.0.0-20220823124025-807a23277127 // indirect 103 | golang.org/x/mod v0.8.0 // indirect 104 | golang.org/x/net v0.8.0 // indirect 105 | golang.org/x/sys v0.6.0 // indirect 106 | golang.org/x/term v0.6.0 // indirect 107 | golang.org/x/text v0.8.0 // indirect 108 | google.golang.org/genproto v0.0.0-20230320184635-7606e756e683 // indirect 109 | google.golang.org/grpc v1.53.0 // indirect 110 | google.golang.org/protobuf v1.29.1 // indirect 111 | gopkg.in/square/go-jose.v2 v2.6.0 // indirect 112 | gopkg.in/yaml.v2 v2.4.0 // indirect 113 | gopkg.in/yaml.v3 v3.0.1 // indirect 114 | ) 115 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "net/http" 6 | "os" 7 | "os/signal" 8 | "runtime" 9 | "strconv" 10 | "strings" 11 | "syscall" 12 | "time" 13 | 14 | "github.com/alicebob/miniredis" 15 | "github.com/flashbots/prio-load-balancer/server" 16 | "github.com/flashbots/prio-load-balancer/testutils" 17 | "go.uber.org/zap" 18 | ) 19 | 20 | var ( 21 | version = "dev" // is set during build process 22 | 23 | // Default values 24 | // 
defaultDebug = os.Getenv("DEBUG") == "1" 25 | defaultRedis = getEnv("REDIS_URI", "dev") 26 | defaultListenAddr = getEnv("LISTEN_ADDR", "localhost:8080") 27 | defaultlogProd = os.Getenv("LOG_PROD") == "1" 28 | defaultLogService = os.Getenv("LOG_SERVICE") 29 | defaultNodeWorkers = getEnvInt("NUM_NODE_WORKERS", 8) // number of maximum concurrent requests per node 30 | defaultNodes = os.Getenv("NODES") 31 | defaultBackends = os.Getenv("BACKENDS") 32 | 33 | // Flags 34 | httpAddrPtr = flag.String("http", defaultListenAddr, "http service address") 35 | // debugPtr = flag.Bool("debug", defaultDebug, "print debug output") 36 | nodeWorkersPtr = flag.Int("node-workers", defaultNodeWorkers, "number of concurrent workers per node") 37 | nodesPtr = flag.String("nodes", defaultNodes, "nodes to use (comma separated)") 38 | backendsPtr = flag.String("backends", defaultBackends, "backend nodes to use (comma separated URLs to proxy requests to)") 39 | redisPtr = flag.String("redis", defaultRedis, "redis URI ('dev' for built-in)") 40 | useMockNodePtr = flag.Bool("mock-node", false, "run a mock node backend") 41 | logProdPtr = flag.Bool("log-prod", defaultlogProd, "production logging") 42 | logServicePtr = flag.String("log-service", defaultLogService, "'service' tag to logs") 43 | ) 44 | 45 | func perr(err error) { 46 | if err != nil { 47 | panic(err) 48 | } 49 | } 50 | 51 | func main() { 52 | flag.Parse() 53 | 54 | // Setup logging 55 | var logger *zap.Logger 56 | if *logProdPtr { 57 | logger, _ = zap.NewProduction() 58 | } else { 59 | logger, _ = zap.NewDevelopment() 60 | } 61 | log := logger.Sugar() 62 | if *logServicePtr != "" { 63 | log = log.With("service", *logServicePtr) 64 | } 65 | log.Infow("Starting prio-load-balancer", "version", version) 66 | 67 | // Setup the redis connection 68 | if *redisPtr == "dev" { 69 | log.Info("Using integrated in-memory Redis instance") 70 | redisServer, err := miniredis.Run() 71 | perr(err) 72 | *redisPtr = redisServer.Addr() 73 | } 74 | 75 | 
serverOpts := server.ServerOpts{ 76 | Log: log, 77 | RedisURI: *redisPtr, 78 | WorkersPerNode: int32(*nodeWorkersPtr), 79 | HTTPAddrPtr: *httpAddrPtr, 80 | } 81 | 82 | srv, err := server.NewServer(serverOpts) 83 | perr(err) 84 | 85 | if *useMockNodePtr { 86 | addr := "localhost:8095" 87 | mockNodeBackend := testutils.NewMockNodeBackend() 88 | http.HandleFunc("/", mockNodeBackend.Handler) 89 | log.Info("Using mock node backend", "listenAddr", addr) 90 | go http.ListenAndServe(addr, nil) 91 | perr(srv.AddNode("http://" + addr)) 92 | 93 | // enable additional APIs in dev mode by default 94 | server.EnableErrorTestAPI = true // will be used later, in srv.Start() 95 | server.EnablePprof = true 96 | } 97 | 98 | if *nodesPtr != "" { 99 | for _, uri := range strings.Split(*nodesPtr, ",") { 100 | perr(srv.AddNode(uri)) 101 | } 102 | } 103 | 104 | if *backendsPtr != "" { 105 | for _, uri := range strings.Split(*backendsPtr, ",") { 106 | perr(srv.AddNode(uri)) 107 | } 108 | } 109 | 110 | go func() { // All 10 seconds: log stats 111 | for { 112 | time.Sleep(10 * time.Second) 113 | log.Infow("goroutines:", "numGoroutines", runtime.NumGoroutine()) 114 | lenFastTrack, lenHighPrio, lenLowPrio := srv.QueueSize() 115 | log.Infow("prioQueue size:", "fastTrack", lenFastTrack, "highPrio", lenHighPrio, "lowPrio", lenLowPrio) 116 | } 117 | }() 118 | 119 | // Handle shutdown gracefully 120 | go func() { 121 | exit := make(chan os.Signal, 1) 122 | signal.Notify(exit, os.Interrupt, syscall.SIGTERM) 123 | <-exit 124 | log.Info("Shutting down...") 125 | srv.Shutdown() 126 | }() 127 | 128 | // Log the current config 129 | server.LogConfig(log) 130 | 131 | // Start the server 132 | srv.Start() 133 | log.Info("bye") 134 | } 135 | 136 | func getEnv(key, defaultValue string) string { 137 | if value, ok := os.LookupEnv(key); ok { 138 | return value 139 | } 140 | return defaultValue 141 | } 142 | 143 | func getEnvInt(key string, defaultValue int) int { 144 | if value, ok := os.LookupEnv(key); ok { 
145 | val, err := strconv.Atoi(value) 146 | if err == nil { 147 | return val 148 | } 149 | } 150 | return defaultValue 151 | } 152 | -------------------------------------------------------------------------------- /server/consts.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "os" 5 | "time" 6 | 7 | "go.uber.org/zap" 8 | ) 9 | 10 | var ( 11 | JobChannelBuffer = GetEnvInt("JOB_CHAN_BUFFER", 2) // buffer for JobC in backends (for transporting jobs from server -> backend node) 12 | RequestMaxTries = GetEnvInt("RETRIES_MAX", 3) // 3 tries means it will be retried 2 additional times, and on third error would fail 13 | PayloadMaxBytes = GetEnvInt("PAYLOAD_MAX_KB", 8192) * 1024 // Max payload size in bytes. If a payload sent to the webserver is larger, it returns "400 Bad Request". 14 | 15 | MaxQueueItemsFastTrack = GetEnvInt("ITEMS_FASTTRACK_MAX", 0) // Max number of items in fast-track queue. 0 means no limit. 16 | MaxQueueItemsHighPrio = GetEnvInt("ITEMS_HIGHPRIO_MAX", 0) // Max number of items in high-prio queue. 0 means no limit. 17 | MaxQueueItemsLowPrio = GetEnvInt("ITEMS_LOWPRIO_MAX", 0) // Max number of items in low-prio queue. 0 means no limit. 
18 | 19 | // How often fast-track queue items should be popped before popping a high-priority item 20 | FastTrackPerHighPrio = GetEnvInt("ITEMS_FASTTRACK_PER_HIGHPRIO", 2) 21 | FastTrackDrainFirst = os.Getenv("FASTTRACK_DRAIN_FIRST") == "1" // whether to fully drain the fast-track queue first 22 | 23 | RequestTimeout = time.Duration(GetEnvInt("REQUEST_TIMEOUT", 5)) * time.Second // Time between creation and receive in the node worker, after which a SimRequest will not be processed anymore 24 | ServerJobSendTimeout = time.Duration(GetEnvInt("JOB_SEND_TIMEOUT", 2)) * time.Second // How long the server tries to send a job into the nodepool for processing 25 | ProxyRequestTimeout = time.Duration(GetEnvInt("REQUEST_PROXY_TIMEOUT", 3)) * time.Second // HTTP request timeout for proxy requests to the backend node 26 | 27 | RedisPrefix = GetEnv("REDIS_PREFIX", "prio-load-balancer:") // All redis keys will be prefixed with this 28 | EnableErrorTestAPI = os.Getenv("ENABLE_ERROR_TEST_API") == "1" // will enable /debug/testLogLevels which prints errors and ends with a panic (also enabled if mock-node is used) 29 | EnablePprof = os.Getenv("ENABLE_PPROF") == "1" // will enable /debug/pprof 30 | 31 | ProxyMaxIdleConns = GetEnvInt("ProxyMaxIdleConns", 100) 32 | ProxyMaxConnsPerHost = GetEnvInt("ProxyMaxConnsPerHost", 100) 33 | ProxyMaxIdleConnsPerHost = GetEnvInt("ProxyMaxIdleConnsPerHost", 100) 34 | ProxyIdleConnTimeout = time.Duration(GetEnvInt("ProxyIdleConnTimeout", 90)) * time.Second 35 | ) 36 | 37 | func LogConfig(log *zap.SugaredLogger) { 38 | log.Infow("config", 39 | "JobChannelBuffer", JobChannelBuffer, 40 | "RequestMaxTries", RequestMaxTries, 41 | "MaxQueueItemsHighPrio", MaxQueueItemsHighPrio, 42 | "MaxQueueItemsLowPrio", MaxQueueItemsLowPrio, 43 | "FastTrackPerHighPrio", FastTrackPerHighPrio, 44 | "FastTrackDrainFirst", FastTrackDrainFirst, 45 | "PayloadMaxBytes", PayloadMaxBytes, 46 | "RequestTimeout", RequestTimeout, 47 | "ServerJobSendTimeout", ServerJobSendTimeout, 
48 | "ProxyRequestTimeout", ProxyRequestTimeout, 49 | "RedisPrefix", RedisPrefix, 50 | "EnableErrorTestAPI", EnableErrorTestAPI, 51 | "EnablePprof", EnablePprof, 52 | "ProxyMaxIdleConns", ProxyMaxIdleConns, 53 | "ProxyMaxConnsPerHost", ProxyMaxConnsPerHost, 54 | "ProxyMaxIdleConnsPerHost", ProxyMaxIdleConnsPerHost, 55 | "ProxyIdleConnTimeout", ProxyIdleConnTimeout, 56 | ) 57 | } 58 | -------------------------------------------------------------------------------- /server/errors.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import "errors" 4 | 5 | var ( 6 | ErrRequestTimeout = errors.New("request timeout hit before processing") 7 | ErrNodeTimeout = errors.New("node timeout") 8 | ErrNoNodesAvailable = errors.New("no nodes available") 9 | ) 10 | -------------------------------------------------------------------------------- /server/http_logger.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "runtime/debug" 7 | "time" 8 | 9 | "go.uber.org/zap" 10 | ) 11 | 12 | // responseWriter is a minimal wrapper for http.ResponseWriter that allows the 13 | // written HTTP status code to be captured for logging. 14 | type responseWriter struct { 15 | http.ResponseWriter 16 | status int 17 | wroteHeader bool 18 | } 19 | 20 | func wrapResponseWriter(w http.ResponseWriter) *responseWriter { 21 | return &responseWriter{ResponseWriter: w} 22 | } 23 | 24 | func (rw *responseWriter) Status() int { 25 | return rw.status 26 | } 27 | 28 | func (rw *responseWriter) WriteHeader(code int) { 29 | if rw.wroteHeader { 30 | return 31 | } 32 | 33 | rw.status = code 34 | rw.ResponseWriter.WriteHeader(code) 35 | rw.wroteHeader = true 36 | } 37 | 38 | // LoggingMiddleware logs the incoming HTTP request & its duration. 
39 | func LoggingMiddleware(log *zap.SugaredLogger, next http.Handler) http.Handler { 40 | return http.HandlerFunc( 41 | func(w http.ResponseWriter, r *http.Request) { 42 | defer func() { 43 | if err := recover(); err != nil { 44 | w.WriteHeader(http.StatusInternalServerError) 45 | log.Errorw(fmt.Sprintf("http request panic: %s %s", r.Method, r.URL.EscapedPath()), 46 | "err", err, 47 | "trace", debug.Stack(), 48 | ) 49 | } 50 | }() 51 | start := time.Now() 52 | wrapped := wrapResponseWriter(w) 53 | next.ServeHTTP(wrapped, r) 54 | log.Infow(fmt.Sprintf("http: %s %s %d", r.Method, r.URL.EscapedPath(), wrapped.status), 55 | "status", wrapped.status, 56 | "method", r.Method, 57 | "path", r.URL.EscapedPath(), 58 | "duration", time.Since(start).Seconds(), 59 | ) 60 | }, 61 | ) 62 | } 63 | -------------------------------------------------------------------------------- /server/node.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "strconv" 10 | "sync/atomic" 11 | "time" 12 | 13 | "github.com/pkg/errors" 14 | "go.uber.org/zap" 15 | ) 16 | 17 | type Node struct { 18 | log *zap.SugaredLogger 19 | URI string 20 | AddedAt time.Time 21 | jobC chan *SimRequest 22 | numWorkers int32 23 | curWorkers int32 24 | cancelContext context.Context 25 | cancelFunc context.CancelFunc 26 | client *http.Client 27 | } 28 | 29 | func (n *Node) HealthCheck() error { 30 | payload := `{"jsonrpc":"2.0","method":"net_version","params":[],"id":123}` 31 | _, _, err := n.ProxyRequest(context.Background(), []byte(payload), 5*time.Second) 32 | return err 33 | } 34 | 35 | func (n *Node) startProxyWorker(id int32, cancelContext context.Context) { 36 | log := n.log.With( 37 | "uri", n.URI, 38 | "id", id, 39 | ) 40 | log.Infow("starting proxy node worker") 41 | atomic.AddInt32(&n.curWorkers, 1) 42 | defer atomic.AddInt32(&n.curWorkers, -1) 43 | 44 | for { 45 | select { 46 | 
case req := <-n.jobC: 47 | _log := log.With("reqID", req.ID) 48 | _log.Debug("processing request") 49 | 50 | if req.Cancelled { 51 | _log.Info("request was cancelled before processing") 52 | continue 53 | } 54 | 55 | if time.Since(req.CreatedAt) > RequestTimeout { 56 | _log.Info("request timed out before processing") 57 | req.SendResponse(SimResponse{Error: ErrRequestTimeout}) 58 | continue 59 | } 60 | 61 | req.Tries += 1 62 | timeBeforeProxy := time.Now().UTC() 63 | payload, statusCode, err := n.ProxyRequest(req.Context, req.Payload, ProxyRequestTimeout) 64 | requestDuration := time.Since(timeBeforeProxy) 65 | _log = _log.With("requestDurationUS", requestDuration.Microseconds()) 66 | if err != nil { 67 | // if not context deadline exceeded 68 | if errors.Is(err, context.DeadlineExceeded) { 69 | _log.Infow("node proxyRequest error: context deatline exeeded", "uri", n.URI, "error", err) 70 | } else { 71 | _log.Errorw("node proxyRequest error", "uri", n.URI, "error", err) 72 | } 73 | response := SimResponse{StatusCode: statusCode, Payload: payload, Error: err, ShouldRetry: true, NodeURI: n.URI} 74 | req.SendResponse(response) 75 | continue 76 | } 77 | 78 | // Send response 79 | _log.Debug("request processed, sending response") 80 | sent := req.SendResponse(SimResponse{Payload: payload, NodeURI: n.URI, SimDuration: requestDuration, SimAt: timeBeforeProxy}) 81 | if !sent { 82 | _log.Errorw("couldn't send node response to client (SendResponse returned false)", "secSinceRequestCreated", time.Since(req.CreatedAt).Seconds()) 83 | } 84 | 85 | case <-cancelContext.Done(): 86 | log.Infow("node worker stopped") 87 | return 88 | } 89 | } 90 | } 91 | 92 | // StartWorkers spawns the proxy workers in goroutines. Workers that are already running will be cancelled. 
93 | func (n *Node) StartWorkers() { 94 | if n.cancelFunc != nil { 95 | n.cancelFunc() 96 | } 97 | 98 | n.cancelContext, n.cancelFunc = context.WithCancel(context.Background()) 99 | for i := int32(0); i < n.numWorkers; i++ { 100 | go n.startProxyWorker(i+1, n.cancelContext) 101 | } 102 | } 103 | 104 | func (n *Node) StopWorkers() { 105 | if n.cancelFunc != nil { 106 | n.cancelFunc() 107 | } 108 | } 109 | 110 | func (n *Node) StopWorkersAndWait() { 111 | n.StopWorkers() 112 | for { 113 | if n.curWorkers == 0 { 114 | break 115 | } 116 | time.Sleep(100 * time.Millisecond) 117 | } 118 | } 119 | 120 | func (n *Node) ProxyRequest(ctx context.Context, payload []byte, timeout time.Duration) (resp []byte, statusCode int, err error) { 121 | ctxx, cancel := context.WithTimeout(ctx, timeout) 122 | defer cancel() 123 | httpReq, err := http.NewRequestWithContext(ctxx, "POST", n.URI, bytes.NewBuffer(payload)) 124 | if err != nil { 125 | return resp, statusCode, errors.Wrap(err, "creating proxy request failed") 126 | } 127 | 128 | httpReq.Header.Set("Accept", "application/json") 129 | httpReq.Header.Set("Content-Type", "application/json") 130 | httpReq.Header.Set("Content-Length", strconv.Itoa(len(payload))) 131 | 132 | httpResp, err := n.client.Do(httpReq) 133 | if err != nil { 134 | return resp, statusCode, errors.Wrap(err, "proxying request failed") 135 | } 136 | 137 | statusCode = httpResp.StatusCode 138 | 139 | defer httpResp.Body.Close() 140 | httpRespBody, err := io.ReadAll(httpResp.Body) 141 | if err != nil { 142 | return resp, statusCode, errors.Wrap(err, "decoding proxying response failed") 143 | } 144 | 145 | if statusCode >= 400 { 146 | return httpRespBody, statusCode, fmt.Errorf("error in response - statusCode: %d / %s", statusCode, httpRespBody) 147 | } 148 | 149 | return httpRespBody, statusCode, nil 150 | } 151 | -------------------------------------------------------------------------------- /server/node_notee.go: 
-------------------------------------------------------------------------------- 1 | //go:build !tee 2 | // +build !tee 3 | 4 | package server 5 | 6 | import ( 7 | "net/http" 8 | "net/url" 9 | "strconv" 10 | "time" 11 | 12 | "go.uber.org/zap" 13 | ) 14 | 15 | func NewNode(log *zap.SugaredLogger, uri string, jobC chan *SimRequest, numWorkers int32) (*Node, error) { 16 | pURL, err := url.ParseRequestURI(uri) 17 | if err != nil { 18 | return nil, err 19 | } 20 | 21 | workersArg := pURL.Query().Get("_workers") 22 | if workersArg != "" { 23 | // set numWorkers from query param 24 | workersInt, err := strconv.Atoi(workersArg) 25 | if err != nil { 26 | log.Errorw("Error parsing workers query param", "err", err, "uri", uri) 27 | } else { 28 | log.Infow("Using custom number of workers", "workers", workersInt, "uri", uri) 29 | numWorkers = int32(workersInt) 30 | } 31 | } 32 | 33 | node := &Node{ 34 | log: log, 35 | URI: uri, 36 | AddedAt: time.Now(), 37 | jobC: jobC, 38 | numWorkers: numWorkers, 39 | client: &http.Client{ 40 | Timeout: ProxyRequestTimeout, 41 | Transport: &http.Transport{ 42 | MaxIdleConns: ProxyMaxIdleConns, 43 | MaxConnsPerHost: ProxyMaxConnsPerHost, 44 | MaxIdleConnsPerHost: ProxyMaxIdleConnsPerHost, 45 | IdleConnTimeout: ProxyIdleConnTimeout, 46 | }, 47 | }, 48 | } 49 | return node, nil 50 | } 51 | -------------------------------------------------------------------------------- /server/node_tee.go: -------------------------------------------------------------------------------- 1 | //go:build tee 2 | // +build tee 3 | 4 | package server 5 | 6 | import ( 7 | "bytes" 8 | "compress/gzip" 9 | "crypto/tls" 10 | "encoding/base64" 11 | "encoding/hex" 12 | "encoding/json" 13 | "fmt" 14 | "io/ioutil" 15 | "net/http" 16 | "net/url" 17 | "strconv" 18 | "strings" 19 | "time" 20 | 21 | "github.com/konvera/geth-sev/constellation/atls" 22 | "github.com/konvera/geth-sev/constellation/attestation/azure/snp" 23 | "github.com/konvera/geth-sev/constellation/config" 24 | 
ratls "github.com/konvera/gramine-ratls-golang" 25 | "go.uber.org/zap" 26 | ) 27 | 28 | func init() { 29 | err := ratls.InitRATLSLib(true, time.Hour, false) 30 | if err != nil { 31 | panic(err) 32 | } 33 | } 34 | 35 | type attestationLogger struct { 36 | log *zap.SugaredLogger 37 | } 38 | 39 | func (w attestationLogger) Infof(format string, args ...any) { 40 | w.log.Infow(fmt.Sprintf(format, args...)) 41 | } 42 | 43 | func (w attestationLogger) Warnf(format string, args ...any) { 44 | w.log.Warnw(fmt.Sprintf(format, args...)) 45 | } 46 | 47 | func NewNode(log *zap.SugaredLogger, uri string, jobC chan *SimRequest, numWorkers int32) (*Node, error) { 48 | client := http.Client{} 49 | pURL, err := url.ParseRequestURI(uri) 50 | if err != nil { 51 | return nil, err 52 | } 53 | username := pURL.User.Username() 54 | 55 | workersArg := pURL.Query().Get("_workers") 56 | if workersArg != "" { 57 | // set numWorkers from query param 58 | workersInt, err := strconv.Atoi(workersArg) 59 | if err != nil { 60 | log.Errorw("Error parsing workers query param", "err", err, "uri", uri) 61 | } else { 62 | log.Infow("Using custom number of workers", "workers", workersInt, "uri", uri) 63 | numWorkers = int32(workersInt) 64 | } 65 | } 66 | 67 | if strings.HasPrefix(username, "SGX_") { // SGX TLS config 68 | mrenclave, err := hex.DecodeString(strings.TrimPrefix(username, "SGX_")) 69 | if err != nil { 70 | return nil, err 71 | } 72 | verifyConnection := func(cs tls.ConnectionState) error { 73 | err := ratls.RATLSVerifyDer(cs.PeerCertificates[0].Raw, mrenclave, nil, nil, nil) 74 | return err 75 | } 76 | client = http.Client{ 77 | Transport: &http.Transport{ 78 | TLSClientConfig: &tls.Config{ 79 | InsecureSkipVerify: true, 80 | VerifyConnection: verifyConnection, 81 | }, 82 | }, 83 | } 84 | } else if strings.HasPrefix(username, "SEV_") { // SEV TLS config 85 | gzmeasurements, err := base64.URLEncoding.DecodeString(strings.TrimPrefix(username, "SEV_")) 86 | if err != nil { 87 | return nil, err 
88 | } 89 | 90 | gzreader, err := gzip.NewReader(bytes.NewReader(gzmeasurements)) 91 | if err != nil { 92 | return nil, err 93 | } 94 | 95 | measurements, err := ioutil.ReadAll(gzreader) 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | attConfig := config.DefaultForAzureSEVSNP() 101 | err = json.Unmarshal(measurements, &attConfig.Measurements) 102 | if err != nil { 103 | return nil, err 104 | } 105 | 106 | validators := []atls.Validator{snp.NewValidator(attConfig, attestationLogger{log})} 107 | tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) 108 | if err != nil { 109 | return nil, err 110 | } 111 | client = http.Client{ 112 | Timeout: ProxyRequestTimeout, 113 | Transport: &http.Transport{ 114 | TLSClientConfig: tlsConfig, 115 | MaxIdleConns: ProxyMaxIdleConns, 116 | MaxConnsPerHost: ProxyMaxConnsPerHost, 117 | MaxIdleConnsPerHost: ProxyMaxIdleConnsPerHost, 118 | IdleConnTimeout: ProxyIdleConnTimeout, 119 | }, 120 | } 121 | } 122 | 123 | node := &Node{ 124 | log: log, 125 | URI: uri, 126 | AddedAt: time.Now(), 127 | jobC: jobC, 128 | numWorkers: numWorkers, 129 | client: &client, 130 | } 131 | return node, nil 132 | } 133 | -------------------------------------------------------------------------------- /server/node_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | "time" 9 | 10 | "github.com/flashbots/prio-load-balancer/testutils" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestNode(t *testing.T) { 15 | mockNodeBackend1 := testutils.NewMockNodeBackend() 16 | mockNodeServer1 := httptest.NewServer(http.HandlerFunc(mockNodeBackend1.Handler)) 17 | 18 | jobC := make(chan *SimRequest) 19 | node, err := NewNode(testLog, mockNodeServer1.URL, jobC, 1) 20 | require.Nil(t, err, err) 21 | 22 | err = node.HealthCheck() 23 | require.Nil(t, err, err) 24 | 25 | request := 
// TestNodeError verifies that backend HTTP error statuses propagate through
// the healthcheck, ProxyRequest, and the worker/SimRequest path.
func TestNodeError(t *testing.T) {
	mockNodeBackend := testutils.NewMockNodeBackend()
	mockNodeServer := httptest.NewServer(http.HandlerFunc(mockNodeBackend.Handler))
	// Force the backend to always respond with an uncommon status code (479)
	// so we can assert it is passed through unchanged.
	mockNodeBackend.HTTPHandlerOverride = func(w http.ResponseWriter, req *http.Request) {
		http.Error(w, "error", 479)
	}

	jobC := make(chan *SimRequest)
	node, err := NewNode(testLog, mockNodeServer.URL, jobC, 1)
	require.Nil(t, err, err)

	// Check failing healthcheck
	err = node.HealthCheck()
	require.NotNil(t, err, err)
	require.Contains(t, err.Error(), "479")

	// Check failing ProxyRequest
	_, statusCode, err := node.ProxyRequest(context.Background(), []byte("net_version"), 3*time.Second)
	require.NotNil(t, err, err)
	require.Equal(t, 479, statusCode)

	// Check failing SimRequest
	request := NewSimRequest(context.Background(), "1", []byte("foo"), true, false)
	node.StartWorkers()
	node.jobC <- request
	res := <-request.ResponseC
	require.NotNil(t, res, res)
	require.NotNil(t, res.Error, res.Error)
	require.Contains(t, res.Error.Error(), "error")
	require.Contains(t, res.Error.Error(), "479")
	require.Equal(t, 479, res.StatusCode)
}

// TestWorkersArg verifies that the `_workers` query parameter overrides the
// default worker count passed to NewNode.
func TestWorkersArg(t *testing.T) {
	mockNodeBackend1 := testutils.NewMockNodeBackend()
	mockNodeServer1 := httptest.NewServer(http.HandlerFunc(mockNodeBackend1.Handler))
	jobC := make(chan *SimRequest)

	// Without the query param, the explicit argument (1) is used.
	node, err := NewNode(testLog, mockNodeServer1.URL, jobC, 1)
	require.Nil(t, err, err)
	require.Equal(t, int32(1), node.numWorkers)

	// With `_workers=4`, the query param wins over the argument.
	uriWithWorkers := mockNodeServer1.URL + "?_workers=4"
	node, err = NewNode(testLog, uriWithWorkers, jobC, 1)
	require.Nil(t, err, err)
	require.Equal(t, int32(4), node.numWorkers)

	uriWithWorkers = mockNodeServer1.URL + "?_workers=6"
	node, err = NewNode(testLog, uriWithWorkers, jobC, 1)
	require.Nil(t, err, err)
	require.Equal(t, int32(6), node.numWorkers)
}
(gp *NodePool) HasNode(uri string) bool { 50 | for _, node := range gp.nodes { 51 | if node.URI == uri { 52 | return true 53 | } 54 | } 55 | return false 56 | } 57 | 58 | // AddNode adds a node to the pool and starts the workers. If a new node is added, the list of nodes is saved to redis. 59 | func (gp *NodePool) AddNode(uri string) error { 60 | added, nodeUris, err := gp._addNode(uri) 61 | if err != nil { 62 | return errors.Wrap(err, "AddNode failed") 63 | } 64 | 65 | if added { 66 | err = gp._saveNodeListToRedis(nodeUris) 67 | if err != nil { 68 | gp.log.Errorw("NodePool AddNode: added but failed saving to redis", "URI", uri, "error", err) 69 | } else { 70 | gp.log.Debugw("NodePool AddNode: added and saved to redis", "URI", uri, "numNodes", len(gp.nodes)) 71 | } 72 | } 73 | 74 | return err 75 | } 76 | 77 | // _addNode adds a node to the pool and starts the workers. If a new node is added, it also returns nodeUris to be saved to redis. 78 | func (gp *NodePool) _addNode(uri string) (added bool, nodeUris []string, err error) { 79 | gp.nodesLock.Lock() 80 | defer gp.nodesLock.Unlock() 81 | 82 | if gp.HasNode(uri) { 83 | return false, nil, nil 84 | } 85 | 86 | node, err := NewNode(gp.log, uri, gp.JobC, gp.numWorkersPerNode) 87 | if err != nil { 88 | return false, nil, err 89 | } 90 | 91 | err = node.HealthCheck() 92 | if err != nil { 93 | return false, nil, errors.Wrap(err, "_addNode healthcheck failed") 94 | } 95 | 96 | // Add now 97 | gp.nodes = append(gp.nodes, node) 98 | nodeUris = []string{} 99 | for _, node := range gp.nodes { 100 | nodeUris = append(nodeUris, node.URI) 101 | } 102 | 103 | // Start node workers 104 | node.StartWorkers() 105 | gp.log.Infow("NodePool: added node", "URI", uri, "numNodes", len(gp.nodes)) 106 | return true, nodeUris, nil 107 | } 108 | 109 | func (gp *NodePool) _saveNodeListToRedis(nodeUris []string) error { 110 | if gp.redisState == nil { 111 | return nil 112 | } 113 | 114 | return gp.redisState.SaveNodes(nodeUris) 115 | } 116 | 117 
| func (gp *NodePool) DelNode(uri string) (deleted bool, err error) { 118 | for idx, node := range gp.nodes { 119 | if node.URI == uri { 120 | node.StopWorkers() 121 | 122 | gp.nodesLock.Lock() 123 | defer gp.nodesLock.Unlock() 124 | 125 | // Remove node 126 | gp.nodes = append(gp.nodes[:idx], gp.nodes[idx+1:]...) 127 | 128 | // Save new list of nodes to redis 129 | nodeUris := []string{} 130 | for _, node := range gp.nodes { 131 | nodeUris = append(nodeUris, node.URI) 132 | } 133 | err = gp._saveNodeListToRedis(nodeUris) 134 | return true, err 135 | } 136 | } 137 | return false, nil 138 | } 139 | 140 | func (gp *NodePool) NodeUris() []string { 141 | gp.nodesLock.Lock() 142 | defer gp.nodesLock.Unlock() 143 | 144 | nodeUris := []string{} 145 | for _, node := range gp.nodes { 146 | nodeUris = append(nodeUris, node.URI) 147 | } 148 | return nodeUris 149 | } 150 | 151 | // Shutdown will stop all node workers, but let's them finish the ongoing connections 152 | func (gp *NodePool) Shutdown() { 153 | for _, node := range gp.nodes { 154 | node.StopWorkersAndWait() 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /server/nodepool_test.go: -------------------------------------------------------------------------------- 1 | // Manages pool of execution nodes 2 | package server 3 | 4 | import ( 5 | "context" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/flashbots/prio-load-balancer/testutils" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestNodePool(t *testing.T) { 15 | resetTestRedis() 16 | mockNodeBackend1 := testutils.NewMockNodeBackend() 17 | mockNodeServer1 := httptest.NewServer(http.HandlerFunc(mockNodeBackend1.Handler)) 18 | 19 | mockNodeBackend2 := testutils.NewMockNodeBackend() 20 | mockNodeServer2 := httptest.NewServer(http.HandlerFunc(mockNodeBackend2.Handler)) 21 | 22 | gp := NewNodePool(testLog, redisTestState, 1) 23 | err := gp.AddNode(mockNodeServer1.URL) 24 | 
// TestNodePoolWithoutREDIS verifies that a pool constructed with a nil redis
// state can still add nodes (persistence becomes a no-op).
func TestNodePoolWithoutREDIS(t *testing.T) {
	mockNodeBackend1 := testutils.NewMockNodeBackend()
	mockNodeServer1 := httptest.NewServer(http.HandlerFunc(mockNodeBackend1.Handler))

	mockNodeBackend2 := testutils.NewMockNodeBackend()
	mockNodeServer2 := httptest.NewServer(http.HandlerFunc(mockNodeBackend2.Handler))

	gp := NewNodePool(testLog, nil, 1)
	err := gp.AddNode(mockNodeServer1.URL)
	require.Nil(t, err, err)

	err = gp.AddNode(mockNodeServer2.URL)
	require.Nil(t, err, err)
}

// TestNodePoolProxy verifies the happy path: a SimRequest pushed into the
// pool's JobC is proxied to the backend and yields an error-free response.
func TestNodePoolProxy(t *testing.T) {
	resetTestRedis()
	mockNodeBackend := testutils.NewMockNodeBackend()
	rpcBackendServer := httptest.NewServer(http.HandlerFunc(mockNodeBackend.Handler))

	gp := NewNodePool(testLog, redisTestState, 1)
	err := gp.AddNode(rpcBackendServer.URL)
	require.Nil(t, err, err)

	request := NewSimRequest(context.Background(), "1", []byte("foo"), true, false)

	gp.JobC <- request
	res := <-request.ResponseC
	require.NotNil(t, res)
	require.Nil(t, res.Error, res.Error)
	require.Equal(t, 0, res.StatusCode)
}

// TestNodePoolWithError verifies that a backend error status surfaces as a
// SimResponse error on the request's response channel.
func TestNodePoolWithError(t *testing.T) {
	mockNodeBackend := testutils.NewMockNodeBackend()
	mockNodeServer := httptest.NewServer(http.HandlerFunc(mockNodeBackend.Handler))

	gp := NewNodePool(testLog, nil, 1)
	err := gp.AddNode(mockNodeServer.URL)
	require.Nil(t, err, err)

	// Switch the backend to an error handler only AFTER AddNode, so the
	// initial healthcheck performed by AddNode still succeeds.
	mockNodeBackend.HTTPHandlerOverride = func(w http.ResponseWriter, req *http.Request) {
		http.Error(w, "error", 479)
	}

	request := NewSimRequest(context.Background(), "1", []byte("foo"), true, false)
	gp.JobC <- request
	res := <-request.ResponseC
	require.NotNil(t, res)
	require.NotNil(t, res.Error, res.Error)
}
27 | 28 | numFastTrackForHighPrio int 29 | fastTrackDrainFirst bool 30 | } 31 | 32 | func NewPrioQueue(maxFastTrack, maxHighPrio, maxLowPrio, numFastTrackForHighPrio int, fastTrackDrainFirst bool) *PrioQueue { 33 | return &PrioQueue{ 34 | cond: sync.NewCond(&sync.Mutex{}), 35 | maxFastTrack: maxFastTrack, 36 | maxHighPrio: maxHighPrio, 37 | maxLowPrio: maxLowPrio, 38 | 39 | numFastTrackForHighPrio: numFastTrackForHighPrio, 40 | fastTrackDrainFirst: fastTrackDrainFirst, 41 | } 42 | } 43 | 44 | func (q *PrioQueue) Len() (lenFastTrack, lenHighPrio, lenLowPrio int) { 45 | return len(q.fastTrack), len(q.highPrio), len(q.lowPrio) 46 | } 47 | 48 | func (q *PrioQueue) NumRequests() int { 49 | return len(q.fastTrack) + len(q.highPrio) + len(q.lowPrio) 50 | } 51 | 52 | func (q *PrioQueue) String() string { 53 | return fmt.Sprintf("PrioQueue: fastTrack: %d / highPrio: %d / lowPrio: %d", len(q.fastTrack), len(q.highPrio), len(q.lowPrio)) 54 | } 55 | 56 | // Push adds a new item to the end of the queue. 
Returns true if added, false if queue is closed or at max capacity 57 | func (q *PrioQueue) Push(r *SimRequest) bool { 58 | if q.closed.Load() || r == nil { 59 | return false 60 | } 61 | 62 | // If queue limits are set and reached, return false now 63 | if r.IsFastTrack && q.maxFastTrack > 0 && len(q.fastTrack) >= q.maxFastTrack { 64 | return false 65 | } else if r.IsHighPrio && q.maxHighPrio > 0 && len(q.highPrio) >= q.maxHighPrio { 66 | return false 67 | } else if !r.IsHighPrio && q.maxLowPrio > 0 && len(q.lowPrio) >= q.maxLowPrio { 68 | return false 69 | } 70 | 71 | // Wait for the lock 72 | q.cond.L.Lock() 73 | defer q.cond.L.Unlock() 74 | 75 | // Check if closed in the meantime 76 | if q.closed.Load() { 77 | return false 78 | } 79 | 80 | // Add to the queue 81 | if r.IsFastTrack { 82 | q.fastTrack = append(q.fastTrack, r) 83 | } else if r.IsHighPrio { 84 | q.highPrio = append(q.highPrio, r) 85 | } else { 86 | q.lowPrio = append(q.lowPrio, r) 87 | } 88 | 89 | // Unlock and send signal to a listener 90 | q.cond.Signal() 91 | return true 92 | } 93 | 94 | // Pop returns the next Bid. If no task in queue, blocks until there is one again. First drains the high-prio queue, 95 | // then the low-prio one. 
Will return nil only after calling Close() when the queue is empty 96 | func (q *PrioQueue) Pop() (nextReq *SimRequest) { 97 | // Return nil immediately if queue is closed and empty 98 | if q.closed.Load() && len(q.fastTrack) == 0 && len(q.highPrio) == 0 && len(q.lowPrio) == 0 { 99 | return nil 100 | } 101 | 102 | q.cond.L.Lock() 103 | defer q.cond.L.Unlock() 104 | 105 | if len(q.fastTrack) == 0 && len(q.highPrio) == 0 && len(q.lowPrio) == 0 { 106 | if q.closed.Load() { 107 | return nil 108 | } 109 | 110 | q.cond.Wait() 111 | } 112 | 113 | // decide whether to start with fast-track or high-prio queue 114 | processFastTrack := len(q.fastTrack) > 0 115 | if !q.fastTrackDrainFirst { 116 | if processFastTrack { 117 | // only fast-track every so often 118 | if q.nFastTrack.Inc() > int32(q.numFastTrackForHighPrio) { 119 | q.nFastTrack.Store(0) 120 | processFastTrack = false 121 | } 122 | } else { 123 | q.nFastTrack.Store(0) 124 | } 125 | } 126 | 127 | if processFastTrack { // check fast-track queue first 128 | if len(q.fastTrack) > 0 { 129 | nextReq = q.fastTrack[0] 130 | q.fastTrack = q.fastTrack[1:] 131 | } else if len(q.highPrio) > 0 { 132 | nextReq = q.highPrio[0] 133 | q.highPrio = q.highPrio[1:] 134 | } else if len(q.lowPrio) > 0 { 135 | nextReq = q.lowPrio[0] 136 | q.lowPrio = q.lowPrio[1:] 137 | } 138 | } else { // check high-prio queue first 139 | if len(q.highPrio) > 0 { 140 | nextReq = q.highPrio[0] 141 | q.highPrio = q.highPrio[1:] 142 | } else if len(q.fastTrack) > 0 { 143 | nextReq = q.fastTrack[0] 144 | q.fastTrack = q.fastTrack[1:] 145 | } else if len(q.lowPrio) > 0 { 146 | nextReq = q.lowPrio[0] 147 | q.lowPrio = q.lowPrio[1:] 148 | } 149 | } 150 | 151 | // When closed and the last item was taken, signal to CloseAndWait that queue is now empty 152 | if q.closed.Load() && len(q.highPrio) == 0 && len(q.lowPrio) == 0 { 153 | q.cond.Broadcast() 154 | } 155 | 156 | return nextReq 157 | } 158 | 159 | // Close disallows adding any new items with Push(), and 
lets readers using Pop() return nil if queue is empty 160 | func (q *PrioQueue) Close() { 161 | q.closed.Store(true) 162 | if q.NumRequests() == 0 { 163 | q.cond.Broadcast() 164 | } 165 | } 166 | 167 | // CloseAndWait closes the queue and waits until the queue is empty 168 | func (q *PrioQueue) CloseAndWait() { 169 | q.Close() 170 | 171 | // Wait until queue is empty 172 | q.cond.L.Lock() 173 | if q.NumRequests() > 0 { 174 | q.cond.Wait() 175 | } 176 | q.cond.L.Unlock() 177 | } 178 | -------------------------------------------------------------------------------- /server/queue_test.go: -------------------------------------------------------------------------------- 1 | // Manages pool of execution nodes 2 | package server 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func cloneRequest(req *SimRequest) *SimRequest { 15 | return NewSimRequest(context.Background(), "1", req.Payload, req.IsHighPrio, req.IsFastTrack) 16 | } 17 | 18 | func fillQueue(t *testing.T, q *PrioQueue) { 19 | t.Helper() 20 | 21 | taskLowPrio := NewSimRequest(context.Background(), "1", []byte("taskLowPrio"), false, false) 22 | taskHighPrio := NewSimRequest(context.Background(), "1", []byte("taskHighPrio"), true, false) 23 | taskFastTrack := NewSimRequest(context.Background(), "1", []byte("tasFastTrack"), false, true) 24 | 25 | q.Push(taskLowPrio) 26 | q.Push(taskHighPrio) 27 | q.Push(cloneRequest(taskHighPrio)) 28 | q.Push(cloneRequest(taskHighPrio)) 29 | q.Push(cloneRequest(taskHighPrio)) 30 | q.Push(cloneRequest(taskHighPrio)) 31 | q.Push(cloneRequest(taskHighPrio)) 32 | q.Push(cloneRequest(taskHighPrio)) 33 | q.Push(cloneRequest(taskHighPrio)) 34 | q.Push(cloneRequest(taskHighPrio)) 35 | q.Push(cloneRequest(taskHighPrio)) 36 | q.Push(cloneRequest(taskHighPrio)) // 11x highPrio 37 | q.Push(taskFastTrack) 38 | q.Push(cloneRequest(taskFastTrack)) 39 | q.Push(cloneRequest(taskFastTrack)) 40 | 
q.Push(cloneRequest(taskFastTrack)) 41 | q.Push(cloneRequest(taskFastTrack)) // 5x fastTrack 42 | 43 | require.Equal(t, 5, len(q.fastTrack)) 44 | require.Equal(t, 11, len(q.highPrio)) 45 | require.Equal(t, 1, len(q.lowPrio)) 46 | } 47 | 48 | func TestQueueBlockingPop(t *testing.T) { 49 | q := NewPrioQueue(0, 0, 0, 2, false) 50 | taskLowPrio := NewSimRequest(context.Background(), "1", []byte("taskLowPrio"), false, false) 51 | 52 | // Ensure queue.Pop is blocking 53 | t1 := time.Now() 54 | go func() { time.Sleep(100 * time.Millisecond); q.Push(taskLowPrio) }() 55 | resp := q.Pop() 56 | tX := time.Since(t1) 57 | require.NotNil(t, resp) 58 | require.True(t, tX >= 100*time.Millisecond) 59 | } 60 | 61 | func TestQueuePopping(t *testing.T) { 62 | // Test 1 - expected: fastTrack -> highPrio -> fastTrack -> highPrio 63 | q := NewPrioQueue(0, 0, 0, 1, false) 64 | fillQueue(t, q) 65 | for i := 0; i < 5; i++ { 66 | x := q.Pop() 67 | fmt.Println("fast:", x.IsFastTrack, "high-prio:", x.IsHighPrio) 68 | require.Equal(t, true, x.IsFastTrack) 69 | require.Equal(t, true, q.Pop().IsHighPrio) 70 | } 71 | 72 | // next 9 should all be high-prio 73 | for i := 0; i < 6; i++ { 74 | require.Equal(t, true, q.Pop().IsHighPrio) 75 | } 76 | 77 | // last one should be low-prio 78 | require.Equal(t, false, q.Pop().IsHighPrio) 79 | require.Equal(t, 0, len(q.lowPrio)) 80 | require.Equal(t, 0, len(q.highPrio)) 81 | 82 | // Test 2 - expected: 2x fastTrack -> 1x highPrio 83 | q = NewPrioQueue(0, 0, 0, 2, false) 84 | fillQueue(t, q) 85 | require.Equal(t, true, q.Pop().IsFastTrack) 86 | require.Equal(t, true, q.Pop().IsFastTrack) 87 | require.Equal(t, true, q.Pop().IsHighPrio) 88 | require.Equal(t, true, q.Pop().IsFastTrack) 89 | require.Equal(t, true, q.Pop().IsFastTrack) 90 | require.Equal(t, true, q.Pop().IsHighPrio) 91 | require.Equal(t, true, q.Pop().IsFastTrack) 92 | 93 | // Test 3 - expected: all fastTrack -> all highPrio 94 | q = NewPrioQueue(0, 0, 0, 2, true) 95 | fillQueue(t, q) 96 | for i := 
0; i < 5; i++ { 97 | require.Equal(t, true, q.Pop().IsFastTrack) 98 | } 99 | for i := 0; i < 11; i++ { 100 | require.Equal(t, true, q.Pop().IsHighPrio) 101 | } 102 | } 103 | 104 | func TestPrioQueueMultipleReaders(t *testing.T) { 105 | q := NewPrioQueue(0, 0, 0, 2, false) 106 | taskLowPrio := NewSimRequest(context.Background(), "1", []byte("taskLowPrio"), false, false) 107 | 108 | counts := make(map[int]int) 109 | resultC := make(chan int, 4) 110 | 111 | // Goroutine that counts the results 112 | go func() { 113 | for id := range resultC { 114 | counts[id]++ 115 | } 116 | }() 117 | 118 | reader := func(id int) { 119 | for { 120 | resp := q.Pop() 121 | require.NotNil(t, resp) 122 | resultC <- id 123 | time.Sleep(10 * time.Millisecond) 124 | } 125 | } 126 | 127 | // Start 2 readers 128 | go reader(1) 129 | go reader(2) 130 | 131 | // Push 6 tasks 132 | q.Push(taskLowPrio) 133 | q.Push(taskLowPrio) 134 | q.Push(taskLowPrio) 135 | q.Push(taskLowPrio) 136 | q.Push(taskLowPrio) 137 | q.Push(taskLowPrio) 138 | 139 | // Wait a bit for the processing to finish 140 | time.Sleep(100 * time.Millisecond) 141 | 142 | // Each reader should have processed the same number of tasks 143 | require.Equal(t, 3, counts[1]) 144 | require.Equal(t, 3, counts[2]) 145 | } 146 | 147 | func TestPrioQueueVarious(t *testing.T) { 148 | q := NewPrioQueue(0, 0, 0, 2, false) 149 | q.Push(nil) 150 | require.Equal(t, 0, len(q.highPrio)) 151 | require.Equal(t, 0, len(q.lowPrio)) 152 | 153 | require.True(t, len(q.String()) > 5) 154 | } 155 | 156 | // Test used for benchmark: single reader 157 | func _testPrioQueue1(numWorkers, numItems int) *PrioQueue { 158 | q := NewPrioQueue(0, 0, 0, 2, false) 159 | taskLowPrio := NewSimRequest(context.Background(), "1", []byte("taskLowPrio"), false, false) 160 | 161 | var wg sync.WaitGroup 162 | 163 | // Goroutine that drains the queue 164 | for i := 0; i < numWorkers; i++ { 165 | wg.Add(1) 166 | go func() { 167 | defer wg.Done() 168 | for { 169 | resp := q.Pop() 170 
| if resp == nil { 171 | return 172 | } 173 | } 174 | }() 175 | } 176 | 177 | for i := 0; i < numItems; i++ { 178 | q.Push(taskLowPrio) 179 | } 180 | 181 | q.CloseAndWait() 182 | wg.Wait() // ensure that all workers have finished 183 | return q 184 | } 185 | 186 | func TestPrioQueue1(t *testing.T) { 187 | q := _testPrioQueue1(1, 1000) 188 | require.Equal(t, 0, q.NumRequests()) 189 | 190 | q = _testPrioQueue1(5, 100) 191 | require.Equal(t, 0, q.NumRequests()) 192 | } 193 | 194 | func BenchmarkPrioQueue(b *testing.B) { 195 | for i := 0; i < b.N; i++ { 196 | _testPrioQueue1(1, 10_000) 197 | } 198 | } 199 | 200 | func BenchmarkPrioQueueMultiReader(b *testing.B) { 201 | for i := 0; i < b.N; i++ { 202 | _testPrioQueue1(5, 10_000) 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /server/redis.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | 7 | "github.com/go-redis/redis/v8" 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | var RedisKeyNodes = RedisPrefix + "prio-load-balancer:nodes" 12 | 13 | type RedisState struct { 14 | RedisClient *redis.Client 15 | } 16 | 17 | func NewRedisState(redisURI string) (*RedisState, error) { 18 | redisClient := redis.NewClient(&redis.Options{Addr: redisURI}) 19 | if err := redisClient.Get(context.Background(), "somekey").Err(); err != nil && err != redis.Nil { 20 | return nil, errors.Wrap(err, "redis init error") 21 | } 22 | return &RedisState{ 23 | RedisClient: redisClient, 24 | }, nil 25 | } 26 | 27 | func (s *RedisState) SaveNodes(nodeUris []string) error { 28 | msg, err := json.Marshal(nodeUris) 29 | if err != nil { 30 | return err 31 | } 32 | err = s.RedisClient.Set(context.Background(), RedisKeyNodes, msg, 0).Err() 33 | return err 34 | } 35 | 36 | func (s *RedisState) GetNodes() (nodeUris []string, err error) { 37 | res, err := s.RedisClient.Get(context.Background(), 
RedisKeyNodes).Result() 38 | if err != nil { 39 | if err == redis.Nil { 40 | return nodeUris, nil 41 | } 42 | return nil, err 43 | } 44 | 45 | err = json.Unmarshal([]byte(res), &nodeUris) 46 | if err != nil { 47 | return nil, err 48 | } 49 | 50 | return nodeUris, nil 51 | } 52 | -------------------------------------------------------------------------------- /server/redis_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/alicebob/miniredis" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | var ( 11 | redisTestServer *miniredis.Miniredis 12 | redisTestState *RedisState 13 | ) 14 | 15 | func resetTestRedis() { 16 | var err error 17 | if redisTestServer != nil { 18 | redisTestServer.Close() 19 | } 20 | 21 | redisTestServer, err = miniredis.Run() 22 | if err != nil { 23 | panic(err) 24 | } 25 | 26 | redisTestState, err = NewRedisState(redisTestServer.Addr()) 27 | if err != nil { 28 | panic(err) 29 | } 30 | } 31 | 32 | func TestRedisStateSetup(t *testing.T) { 33 | var err error 34 | _, err = NewRedisState("localhost:18279") 35 | require.NotNil(t, err, err) 36 | } 37 | 38 | func TestRedisNodes(t *testing.T) { 39 | resetTestRedis() 40 | 41 | nodes0, err := redisTestState.GetNodes() 42 | require.Nil(t, err, err) 43 | require.Equal(t, 0, len(nodes0)) 44 | 45 | err = redisTestState.SaveNodes([]string{"http://localhost:12431", "http://localhost:12432"}) 46 | require.Nil(t, err, err) 47 | 48 | nodes2, err := redisTestState.GetNodes() 49 | require.Nil(t, err, err) 50 | require.Equal(t, 2, len(nodes2)) 51 | } 52 | -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "go.uber.org/zap" 8 | ) 9 | 10 | type ServerOpts struct { 11 | Log *zap.SugaredLogger 12 | HTTPAddrPtr 
string // listen address for the webserver 13 | RedisURI string // (optional) URI for the redis instance. If empty then don't use Redis. 14 | WorkersPerNode int32 // Number of concurrent workers per execution node 15 | } 16 | 17 | // Server is the overall load balancer server 18 | type Server struct { 19 | log *zap.SugaredLogger 20 | opts ServerOpts 21 | redis *RedisState 22 | prioQueue *PrioQueue 23 | nodePool *NodePool 24 | webserver *Webserver 25 | } 26 | 27 | // NewServer creates a new Server instance, loads the nodes from Redis and starts the node workers 28 | func NewServer(opts ServerOpts) (*Server, error) { 29 | var err error 30 | s := Server{ 31 | opts: opts, 32 | log: opts.Log, 33 | prioQueue: NewPrioQueue(MaxQueueItemsFastTrack, MaxQueueItemsHighPrio, MaxQueueItemsLowPrio, FastTrackPerHighPrio, FastTrackDrainFirst), 34 | } 35 | 36 | if s.opts.RedisURI == "" { 37 | s.log.Info("Not using Redis because no RedisURI provided") 38 | } else { 39 | s.log.Infow("Connecting to Redis", "URI", s.opts.RedisURI) 40 | s.redis, err = NewRedisState(s.opts.RedisURI) 41 | if err != nil { 42 | return nil, err 43 | } 44 | } 45 | 46 | if opts.WorkersPerNode == 0 { 47 | s.log.Warn("WorkersPerNode is 0! This is not recommended. 
Use at least 1.") 48 | } 49 | 50 | s.nodePool = NewNodePool(s.log, s.redis, s.opts.WorkersPerNode) 51 | err = s.nodePool.LoadNodesFromRedis() 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | return &s, nil 57 | } 58 | 59 | // Start starts the webserver and the main loop (pumping jobs from the queue to the workers) 60 | func (s *Server) Start() { 61 | // Setup and start the webserver 62 | s.log.Infow("Starting webserver", "listenAddr", s.opts.HTTPAddrPtr) 63 | s.webserver = NewWebserver(s.log, s.opts.HTTPAddrPtr, s.prioQueue, s.nodePool) 64 | s.webserver.Start() 65 | 66 | // Main loop: send simqueue jobs to node pool 67 | s.log.Info("Starting main loop") 68 | for { 69 | r := s.prioQueue.Pop() 70 | if r == nil { // Shutdown (queue.Close() was called) 71 | s.log.Info("Shutting down main loop (request is nil)") 72 | return 73 | } 74 | 75 | if r.Cancelled { 76 | continue 77 | } 78 | 79 | if time.Since(r.CreatedAt) > RequestTimeout { 80 | s.log.Info("request timed out before processing") 81 | r.SendResponse(SimResponse{Error: ErrRequestTimeout}) 82 | continue 83 | } 84 | 85 | // Return an error if no nodes are available 86 | if len(s.nodePool.nodes) == 0 { 87 | s.log.Error("no execution nodes available") 88 | r.SendResponse(SimResponse{Error: ErrNoNodesAvailable}) 89 | continue 90 | } 91 | 92 | // Forward to a node for processing 93 | select { 94 | case s.nodePool.JobC <- r: 95 | // Job was taken by a node 96 | case <-time.After(ServerJobSendTimeout): 97 | // Job was NOT taken by a node - cancel request 98 | s.log.Warnw("job was not taken by a node", "requestsInQueue", s.prioQueue.NumRequests()) 99 | r.SendResponse(SimResponse{Error: ErrNodeTimeout}) 100 | } 101 | } 102 | } 103 | 104 | // Shutdown gracefully shuts down the server. Allows ongoing requests to complete, but no 105 | // further requests will be accepted or those from the queue processed. 
106 | func (s *Server) Shutdown() { 107 | s.log.Info("Shutting down server") 108 | s.prioQueue.Close() 109 | s.webserver.srv.Shutdown(context.Background()) // stop incoming requests 110 | s.nodePool.Shutdown() // stop the execution workers 111 | } 112 | 113 | // AddNode adds a new execution node to the pool and starts the workers. If a new node is added, 114 | // the list of nodes is saved to redis. 115 | func (s *Server) AddNode(uri string) error { 116 | return s.nodePool.AddNode(uri) 117 | } 118 | 119 | // NumNodeWorkersAlive returns the number of currently active node workers 120 | func (s *Server) NumNodeWorkersAlive() int { 121 | res := 0 122 | for _, n := range s.nodePool.nodes { 123 | res += int(n.curWorkers) 124 | } 125 | return res 126 | } 127 | 128 | func (s *Server) QueueSize() (lenFastTrack, lenHighPrio, lenLowPrio int) { 129 | return s.prioQueue.Len() 130 | } 131 | -------------------------------------------------------------------------------- /server/server_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | "github.com/flashbots/prio-load-balancer/testutils" 13 | "github.com/stretchr/testify/require" 14 | "go.uber.org/zap" 15 | ) 16 | 17 | var ( 18 | testServerListenAddr = "localhost:9498" 19 | testLogger, _ = zap.NewDevelopment() 20 | testLog = testLogger.Sugar() 21 | ) 22 | 23 | func TestServerWithoutRedis(t *testing.T) { 24 | s, err := NewServer(ServerOpts{testLog, testServerListenAddr, "", 1}) 25 | require.Nil(t, err, err) 26 | 27 | mockNodeBackend := testutils.NewMockNodeBackend() 28 | mockNodeServer := httptest.NewServer(http.HandlerFunc(mockNodeBackend.Handler)) 29 | s.AddNode(mockNodeServer.URL) 30 | go s.Start() 31 | defer s.Shutdown() 32 | time.Sleep(100 * time.Millisecond) // give Github CI time to start the webserver 33 | 34 | url := "http://" + 
testServerListenAddr 35 | resp, err := http.PostForm(url, nil) 36 | require.Nil(t, err, err) 37 | require.Equal(t, 200, resp.StatusCode) 38 | 39 | url = "http://" + testServerListenAddr + "/" 40 | resp, err = http.PostForm(url, nil) 41 | require.Nil(t, err, err) 42 | require.Equal(t, 200, resp.StatusCode) 43 | } 44 | 45 | func TestServerWithRedis(t *testing.T) { 46 | resetTestRedis() 47 | 48 | s, err := NewServer(ServerOpts{testLog, testServerListenAddr, redisTestServer.Addr(), 1}) 49 | require.Nil(t, err, err) 50 | 51 | mockNodeBackend := testutils.NewMockNodeBackend() 52 | mockNodeServer := httptest.NewServer(http.HandlerFunc(mockNodeBackend.Handler)) 53 | s.AddNode(mockNodeServer.URL) 54 | go s.Start() 55 | defer s.Shutdown() 56 | time.Sleep(100 * time.Millisecond) // give Github CI time to start the webserver 57 | 58 | url := "http://" + testServerListenAddr + "/" 59 | resp, err := http.PostForm(url, nil) 60 | require.Nil(t, err, err) 61 | require.Equal(t, 200, resp.StatusCode) 62 | } 63 | 64 | func TestServerNoNodes(t *testing.T) { 65 | s, err := NewServer(ServerOpts{testLog, testServerListenAddr, "", 1}) 66 | require.Nil(t, err, err) 67 | go s.Start() 68 | defer s.Shutdown() 69 | time.Sleep(100 * time.Millisecond) // give Github CI time to start the webserver 70 | 71 | url := "http://" + testServerListenAddr 72 | resp, err := http.PostForm(url, nil) 73 | require.Nil(t, err, err) 74 | require.Equal(t, 500, resp.StatusCode) 75 | 76 | bb, _ := io.ReadAll(resp.Body) 77 | require.Contains(t, string(bb), "no nodes") 78 | } 79 | 80 | // TestServerShutdown tests the graceful shutdown of the server 81 | func TestServerShutdown(t *testing.T) { 82 | s, err := NewServer(ServerOpts{testLog, testServerListenAddr, "", 1}) 83 | require.Nil(t, err, err) 84 | 85 | done := make(chan bool) 86 | go func() { 87 | s.Start() 88 | done <- true 89 | }() 90 | 91 | time.Sleep(100 * time.Millisecond) 92 | s.Shutdown() 93 | isDone := <-done 94 | require.True(t, isDone) 95 | } 96 | 97 | // 
TestServerJobTimeout ensures that the server will timeout a job if it takes too long 98 | func TestServerJobTimeout(t *testing.T) { 99 | s, err := NewServer(ServerOpts{testLog, testServerListenAddr, "", 0}) // 0 workers per node -> no jobs can be picked up 100 | require.Nil(t, err, err) 101 | s.nodePool.JobC = make(chan *SimRequest) // disable buffer on job queue 102 | 103 | mockNodeBackend := testutils.NewMockNodeBackend() 104 | mockNodeServer := httptest.NewServer(http.HandlerFunc(mockNodeBackend.Handler)) 105 | s.AddNode(mockNodeServer.URL) 106 | go s.Start() 107 | defer s.Shutdown() 108 | time.Sleep(100 * time.Millisecond) // give Github CI time to start the webserver 109 | 110 | url := "http://" + testServerListenAddr 111 | reqPayload := testutils.NewJSONRPCRequest1(1, "eth_callBundle", "0x1") 112 | reqPayloadBytes, err := json.Marshal(reqPayload) 113 | require.Nil(t, err, err) 114 | resp, _ := http.Post(url, "application/json", bytes.NewBuffer(reqPayloadBytes)) 115 | require.Nil(t, err, err) 116 | require.Equal(t, 500, resp.StatusCode) 117 | lenFT, lenHP, lenLP := s.prioQueue.Len() 118 | require.Equal(t, 0, lenFT+lenHP+lenLP) 119 | } 120 | -------------------------------------------------------------------------------- /server/types.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | type SimRequest struct { 9 | // can be none of, or one of high-prio / fast-track 10 | ID string 11 | IsHighPrio bool 12 | IsFastTrack bool 13 | 14 | Payload []byte 15 | ResponseC chan SimResponse 16 | Cancelled bool 17 | CreatedAt time.Time 18 | Tries int 19 | Context context.Context 20 | } 21 | 22 | func NewSimRequest(ctx context.Context, id string, payload []byte, isHighPrio, IsFastTrack bool) *SimRequest { 23 | return &SimRequest{ 24 | ID: id, 25 | Payload: payload, 26 | IsHighPrio: isHighPrio, 27 | IsFastTrack: IsFastTrack, 28 | ResponseC: make(chan SimResponse, 1), 29 
// GetEnvInt reads the environment variable key and parses it as a base-10
// integer. It returns defaultValue when the variable is unset or does not
// parse as an integer.
func GetEnvInt(key string, defaultValue int) int {
	raw, found := os.LookupEnv(key)
	if !found {
		return defaultValue
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return defaultValue
	}
	return parsed
}

// GetEnv reads the environment variable key, falling back to defaultValue
// when it is unset. A variable that is set to the empty string is returned
// as-is (empty), not replaced by the default.
func GetEnv(key, defaultValue string) string {
	if raw, found := os.LookupEnv(key); found {
		return raw
	}
	return defaultValue
}
2 | package server 3 | 4 | import ( 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | _ "net/http/pprof" 10 | "strings" 11 | "time" 12 | 13 | "github.com/gorilla/mux" 14 | "go.uber.org/zap" 15 | ) 16 | 17 | type Webserver struct { 18 | log *zap.SugaredLogger 19 | listenAddr string 20 | prioQueue *PrioQueue 21 | nodePool *NodePool 22 | srv *http.Server 23 | } 24 | 25 | func NewWebserver(log *zap.SugaredLogger, listenAddr string, prioQueue *PrioQueue, nodePool *NodePool) *Webserver { 26 | return &Webserver{ 27 | log: log, 28 | listenAddr: listenAddr, 29 | prioQueue: prioQueue, 30 | nodePool: nodePool, 31 | } 32 | } 33 | 34 | func (s *Webserver) Start() { 35 | r := mux.NewRouter() 36 | r.HandleFunc("/", s.HandleRootRequest).Methods(http.MethodGet) 37 | r.HandleFunc("/", s.HandleQueueRequest).Methods(http.MethodPost) 38 | r.HandleFunc("/sim", s.HandleQueueRequest).Methods(http.MethodPost) 39 | r.HandleFunc("/nodes", s.HandleNodesRequest).Methods(http.MethodGet, http.MethodPost, http.MethodDelete) 40 | 41 | if EnablePprof { 42 | s.log.Info("Enabling pprof") 43 | r.PathPrefix("/debug/pprof/").Handler(http.DefaultServeMux) 44 | } 45 | 46 | if EnableErrorTestAPI { 47 | s.log.Info("Enabling error testing API") 48 | r.HandleFunc("/debug/testLogLevels", s.HandleTestLogLevels).Methods(http.MethodGet) 49 | } 50 | 51 | loggedRouter := LoggingMiddleware(s.log, r) 52 | 53 | s.srv = &http.Server{ 54 | Addr: s.listenAddr, 55 | Handler: loggedRouter, 56 | } 57 | 58 | go func() { 59 | err := s.srv.ListenAndServe() 60 | if err == http.ErrServerClosed { 61 | return 62 | } 63 | s.log.Errorw("Webserver error", "err", err) 64 | panic(err) 65 | }() 66 | } 67 | 68 | func (s *Webserver) HandleRootRequest(w http.ResponseWriter, req *http.Request) { 69 | w.WriteHeader(http.StatusOK) 70 | fmt.Fprintf(w, "prio-load-balancer\n") 71 | } 72 | 73 | func (s *Webserver) HandleQueueRequest(w http.ResponseWriter, req *http.Request) { 74 | startTime := time.Now().UTC() 75 | defer req.Body.Close() 
76 | 77 | // Allow single `X-Request-ID:...` log field via header 78 | reqID := req.Header.Get("X-Request-ID") 79 | log := s.log 80 | if reqID != "" { 81 | log = s.log.With("reqID", reqID) 82 | } 83 | 84 | // Read the body and start processing 85 | body, err := io.ReadAll(req.Body) 86 | if err != nil { 87 | http.Error(w, err.Error(), http.StatusInternalServerError) 88 | return 89 | } 90 | 91 | if len(body) > PayloadMaxBytes { 92 | http.Error(w, "Payload too large", http.StatusBadRequest) 93 | return 94 | } 95 | 96 | ctx := req.Context() 97 | if ctx.Err() != nil { 98 | log.Infow("client closed the connection before processing", "err", ctx.Err()) 99 | return 100 | } 101 | 102 | // Add new sim request to queue 103 | isFastTrack := req.Header.Get("X-Fast-Track") == "true" 104 | isHighPrio := req.Header.Get("high_prio") == "true" || req.Header.Get("X-High-Priority") == "true" 105 | simReq := NewSimRequest(ctx, reqID, body, isHighPrio, isFastTrack) 106 | wasAdded := s.prioQueue.Push(simReq) 107 | if !wasAdded { // queue was full, job not added 108 | log.Error("Couldn't add request, queue is full") 109 | http.Error(w, "queue full", http.StatusInternalServerError) 110 | return 111 | } 112 | 113 | startQueueSizeFastTrack, startQueueSizeHighPrio, startQueueSizeLowPrio := s.prioQueue.Len() 114 | startItemQueueSize := startQueueSizeLowPrio 115 | if isFastTrack { 116 | startItemQueueSize = startQueueSizeFastTrack 117 | } else if isHighPrio { 118 | startItemQueueSize = startQueueSizeHighPrio 119 | } 120 | 121 | log = log.With( 122 | "requestIsHighPrio", isHighPrio, 123 | "requestIsFastTrack", isFastTrack, 124 | "payloadSize", len(body), 125 | 126 | "startQueueSize", s.prioQueue.NumRequests(), 127 | "startQueueSizeFastTrack", startQueueSizeFastTrack, 128 | "startQueueSizeHighPrio", startQueueSizeHighPrio, 129 | "startQueueSizeLowPrio", startQueueSizeLowPrio, 130 | ) 131 | log.Infow("Request added to queue") 132 | 133 | // Wait for response or cancel 134 | for { 135 | select { 136 
| case <-ctx.Done(): // if user closes connection, cancel the simreq 137 | log.Infow("Client closed the connection prematurely", "err", ctx.Err(), "queueItems", s.prioQueue.NumRequests(), "payloadSize", len(body), "requestTries", simReq.Tries, "requestCancelled", simReq.Cancelled) 138 | if ctx.Err() != nil { 139 | simReq.Cancelled = true 140 | } 141 | return 142 | case resp := <-simReq.ResponseC: 143 | if resp.Error != nil { 144 | log.Infow("Request proxying failed", "err", resp.Error, "try", simReq.Tries, "shouldRetry", resp.ShouldRetry, "nodeURI", resp.NodeURI) 145 | if simReq.Tries < RequestMaxTries && resp.ShouldRetry { 146 | s.prioQueue.Push(simReq) 147 | continue 148 | } 149 | 150 | if resp.StatusCode == 0 { 151 | resp.StatusCode = http.StatusInternalServerError 152 | } 153 | 154 | if len(resp.Payload) > 0 { 155 | w.WriteHeader(resp.StatusCode) 156 | w.Write(resp.Payload) 157 | return 158 | } 159 | 160 | http.Error(w, strings.Trim(resp.Error.Error(), "\n"), resp.StatusCode) 161 | return 162 | } 163 | 164 | if resp.StatusCode == 0 { 165 | resp.StatusCode = http.StatusOK 166 | } 167 | 168 | queueDurationUs := resp.SimAt.Sub(startTime).Microseconds() 169 | endQueueSizeFastTrack, endQueueSizeHighPrio, endQueueSizeLowPrio := s.prioQueue.Len() 170 | endItemQueueSize := endQueueSizeLowPrio 171 | if isFastTrack { 172 | endItemQueueSize = endQueueSizeFastTrack 173 | } else if isHighPrio { 174 | endItemQueueSize = endQueueSizeHighPrio 175 | } 176 | 177 | // Add additional profiling information about this request as part of the response headers 178 | w.Header().Set("X-PrioLB-QueueDurationUs", fmt.Sprint(queueDurationUs)) 179 | w.Header().Set("X-PrioLB-SimDurationUs", fmt.Sprint(resp.SimDuration.Microseconds())) 180 | w.Header().Set("X-PrioLB-TotalDurationUs", fmt.Sprint(time.Since(startTime).Microseconds())) 181 | w.Header().Set("X-PrioLB-QueueSizeStart", fmt.Sprint(startItemQueueSize)) 182 | w.Header().Set("X-PrioLB-QueueSizeEnd", fmt.Sprint(endItemQueueSize)) 183 | 
184 | // Send the response 185 | w.Header().Set("Content-Type", "application/json") 186 | w.WriteHeader(resp.StatusCode) 187 | w.Write(resp.Payload) 188 | 189 | log.Infow("Request completed", 190 | "durationMs", time.Since(startTime).Milliseconds(), // full request duration in milliseconds 191 | "durationUs", time.Since(startTime).Microseconds(), // full request duration in microseconds 192 | "simDurationUs", resp.SimDuration.Microseconds(), // time only for simulation (proxying) 193 | "queueDurationUs", queueDurationUs, // time until request was proxied (queue wait time) 194 | 195 | "statusCode", resp.StatusCode, 196 | "nodeURI", resp.NodeURI, 197 | "requestTries", simReq.Tries, 198 | 199 | "endQueueSize", s.prioQueue.NumRequests(), 200 | "endQueueSizeFastTrack", endQueueSizeFastTrack, 201 | "endQueueSizeHighPrio", endQueueSizeHighPrio, 202 | "endQueueSizeLowPrio", endQueueSizeLowPrio, 203 | ) 204 | return 205 | } 206 | } 207 | } 208 | 209 | type NodeURIPayload struct { 210 | URI string `json:"uri"` 211 | } 212 | 213 | func (s *Webserver) HandleNodesRequest(w http.ResponseWriter, req *http.Request) { 214 | if req.Method == "GET" { 215 | w.Header().Set("Content-Type", "application/json") 216 | if err := json.NewEncoder(w).Encode(s.nodePool.NodeUris()); err != nil { 217 | http.Error(w, err.Error(), http.StatusInternalServerError) 218 | return 219 | } 220 | 221 | } else if req.Method == "POST" { 222 | var payload NodeURIPayload 223 | if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { 224 | http.Error(w, err.Error(), http.StatusBadRequest) 225 | return 226 | } 227 | 228 | if err := s.nodePool.AddNode(payload.URI); err != nil { 229 | http.Error(w, err.Error(), http.StatusBadRequest) 230 | return 231 | } 232 | 233 | w.WriteHeader(http.StatusOK) 234 | 235 | } else if req.Method == "DELETE" { 236 | var payload NodeURIPayload 237 | if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { 238 | http.Error(w, err.Error(), http.StatusBadRequest) 239 
| return 240 | } 241 | 242 | wasRemoved, err := s.nodePool.DelNode(payload.URI) 243 | if err != nil { 244 | http.Error(w, err.Error(), http.StatusBadRequest) 245 | return 246 | } 247 | 248 | if !wasRemoved { 249 | http.Error(w, "node not found", http.StatusBadRequest) 250 | return 251 | } 252 | 253 | w.WriteHeader(http.StatusOK) 254 | } 255 | } 256 | 257 | // HandleTestLogLevels is used for testing error logging, to verify for operations. Is opt-in with `ENABLE_ERROR_TEST_API=1` 258 | func (s *Webserver) HandleTestLogLevels(w http.ResponseWriter, req *http.Request) { 259 | s.log.Debug("debug") 260 | s.log.Infow("info", "key", "value") 261 | s.log.Warnw("warn", "key", "value") 262 | s.log.Errorw("error", "key", "value") 263 | // s.log.Fatalw("fatal", "key", "value") 264 | // s.log.Panicw("panic", "key", "value") 265 | panic("panic") 266 | // w.WriteHeader(http.StatusOK) 267 | } 268 | -------------------------------------------------------------------------------- /server/webserver_test.go: -------------------------------------------------------------------------------- 1 | // Manages pool of execution nodes 2 | package server 3 | 4 | import ( 5 | "bytes" 6 | "context" 7 | "encoding/json" 8 | "errors" 9 | "fmt" 10 | "net/http" 11 | "net/http/httptest" 12 | "testing" 13 | "time" 14 | 15 | "github.com/flashbots/prio-load-balancer/testutils" 16 | "github.com/stretchr/testify/require" 17 | ) 18 | 19 | func TestWebserver(t *testing.T) { 20 | resetTestRedis() 21 | 22 | prioQueue := NewPrioQueue(0, 0, 0, 2, false) 23 | nodePool := NewNodePool(testLog, redisTestState, 1) 24 | webserver := NewWebserver(testLog, ":12345", prioQueue, nodePool) 25 | 26 | // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response. 
27 | handler := http.HandlerFunc(webserver.HandleNodesRequest) 28 | 29 | mockNodeBackend := testutils.NewMockNodeBackend() 30 | mockNodeServer := httptest.NewServer(http.HandlerFunc(mockNodeBackend.Handler)) 31 | 32 | // GET /nodes request (empty list) 33 | getNodesReq, _ := http.NewRequest("GET", "/nodes", nil) 34 | rr := httptest.NewRecorder() 35 | handler.ServeHTTP(rr, getNodesReq) 36 | require.Equal(t, http.StatusOK, rr.Code) 37 | require.Equal(t, "[]\n", rr.Body.String()) 38 | 39 | // Add a node with POST /nodes request 40 | addNodePayload := fmt.Sprintf(`{"uri":"%s"}`, mockNodeServer.URL) 41 | addNodeReq, _ := http.NewRequest("POST", "/nodes", bytes.NewBufferString(addNodePayload)) 42 | rr = httptest.NewRecorder() 43 | handler.ServeHTTP(rr, addNodeReq) 44 | require.Equal(t, http.StatusOK, rr.Code) 45 | require.Equal(t, 1, len(nodePool.nodes)) 46 | 47 | // check redis 48 | nodesFromRedis, err := redisTestState.GetNodes() 49 | require.Nil(t, err, err) 50 | require.Equal(t, 1, len(nodesFromRedis)) 51 | 52 | // Get list of nodes with length 1 53 | getNodesReq, _ = http.NewRequest("GET", "/nodes", nil) 54 | rr = httptest.NewRecorder() 55 | handler.ServeHTTP(rr, getNodesReq) 56 | require.Equal(t, http.StatusOK, rr.Code) 57 | nodes := []string{} 58 | err = json.Unmarshal(rr.Body.Bytes(), &nodes) 59 | require.Nil(t, err, err) 60 | require.Equal(t, 1, len(nodes)) 61 | 62 | // Noop an error on adding a node twice 63 | addNodeReq, _ = http.NewRequest("POST", "/nodes", bytes.NewBufferString(addNodePayload)) 64 | rr = httptest.NewRecorder() 65 | handler.ServeHTTP(rr, addNodeReq) 66 | require.Equal(t, http.StatusOK, rr.Code) 67 | require.Equal(t, 1, len(nodePool.nodes)) 68 | 69 | // Delete a non-existing node with DELETE /nodes request 70 | delNodePayload := `{"uri":"http://localhost:8545X"}` 71 | delNodeReq, _ := http.NewRequest("DELETE", "/nodes", bytes.NewBufferString(delNodePayload)) 72 | rr = httptest.NewRecorder() 73 | handler.ServeHTTP(rr, delNodeReq) 74 | 
require.Equal(t, http.StatusBadRequest, rr.Code) 75 | require.Equal(t, 1, len(nodePool.nodes)) 76 | 77 | // Delete a node with DELETE /nodes request 78 | delNodePayload = fmt.Sprintf(`{"uri":"%s"}`, mockNodeServer.URL) 79 | delNodeReq, _ = http.NewRequest("DELETE", "/nodes", bytes.NewBufferString(delNodePayload)) 80 | rr = httptest.NewRecorder() 81 | handler.ServeHTTP(rr, delNodeReq) 82 | require.Equal(t, http.StatusOK, rr.Code) 83 | require.Equal(t, 0, len(nodePool.nodes)) 84 | 85 | // check redis 86 | nodesFromRedis, err = redisTestState.GetNodes() 87 | require.Nil(t, err, err) 88 | require.Equal(t, 0, len(nodesFromRedis)) 89 | 90 | // Try to add an invalid node with POST /nodes request 91 | addNodePayload = `{"uri":"http://localhost:12354"}` 92 | addNodeReq, _ = http.NewRequest("POST", "/nodes", bytes.NewBufferString(addNodePayload)) 93 | rr = httptest.NewRecorder() 94 | handler.ServeHTTP(rr, addNodeReq) 95 | require.Equal(t, http.StatusBadRequest, rr.Code) 96 | require.Equal(t, 0, len(nodePool.nodes)) 97 | } 98 | 99 | func TestWebserverSim(t *testing.T) { 100 | mockNodeBackend := testutils.NewMockNodeBackend() 101 | mockNodeServer := httptest.NewServer(http.HandlerFunc(mockNodeBackend.Handler)) 102 | 103 | prioQueue := NewPrioQueue(0, 0, 0, 2, false) 104 | nodePool := NewNodePool(testLog, nil, 1) 105 | nodePool.AddNode(mockNodeServer.URL) 106 | webserver := NewWebserver(testLog, ":12345", prioQueue, nodePool) 107 | handler := http.HandlerFunc(webserver.HandleQueueRequest) 108 | 109 | // Pump jobs from prioQueue to nodepool 110 | go func() { 111 | for { 112 | job := prioQueue.Pop() 113 | if job == nil { 114 | return 115 | } 116 | nodePool.JobC <- job 117 | } 118 | }() 119 | 120 | // Test valid sim request 121 | reqPayload := testutils.NewJSONRPCRequest1(1, "eth_callBundle", "0x1") 122 | reqPayloadBytes, err := json.Marshal(reqPayload) 123 | require.Nil(t, err, err) 124 | getSimReq, _ := http.NewRequest("POST", "/", bytes.NewBuffer(reqPayloadBytes)) 125 | rr := 
httptest.NewRecorder() 126 | handler.ServeHTTP(rr, getSimReq) 127 | require.Equal(t, http.StatusOK, rr.Code) 128 | require.Equal(t, `{"id":1,"result":"cool","jsonrpc":"2.0"}`+"\n", rr.Body.String()) 129 | 130 | // Test node error handling 131 | mockNodeBackend.Reset() 132 | mockNodeBackend.HTTPHandlerOverride = func(w http.ResponseWriter, req *http.Request) { 133 | http.Error(w, "error", 479) 134 | } 135 | getSimReq, _ = http.NewRequest("POST", "/", bytes.NewBuffer(reqPayloadBytes)) 136 | rr = httptest.NewRecorder() 137 | handler.ServeHTTP(rr, getSimReq) 138 | require.Equal(t, 479, rr.Code) 139 | require.Equal(t, "error\n", rr.Body.String()) 140 | 141 | // Test request cancelling (using a custom backend handler override to wait for 5 seconds) 142 | mockNodeBackend.Reset() 143 | mockNodeBackend.RPCHandlerOverride = func(req *testutils.JSONRPCRequest) (result interface{}, err error) { 144 | time.Sleep(5 * time.Second) 145 | return nil, errors.New("timeout") 146 | } 147 | getSimReq, _ = http.NewRequest("POST", "/", bytes.NewBuffer(reqPayloadBytes)) 148 | ctx, cancel := context.WithCancel(context.Background()) 149 | getSimReqWithContext := getSimReq.WithContext(ctx) 150 | 151 | rr = httptest.NewRecorder() 152 | t1 := time.Now() 153 | doneC := make(chan bool) 154 | go func() { 155 | handler.ServeHTTP(rr, getSimReqWithContext) // would take 5 seconds without cancelling 156 | doneC <- true 157 | }() 158 | cancel() 159 | <-doneC 160 | tX := time.Since(t1) 161 | require.True(t, tX.Seconds() < 1, "should have been cancelled") 162 | // Here no further requests can be made! 
163 | } 164 | -------------------------------------------------------------------------------- /staticcheck.conf: -------------------------------------------------------------------------------- 1 | checks = ["all"] 2 | # checks = ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022", "-ST1023"] 3 | initialisms = ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "QPS", "RAM", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "GID", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "SIP", "RTP", "AMQP", "DB", "TS"] 4 | dot_import_whitelist = ["github.com/mmcloughlin/avo/build", "github.com/mmcloughlin/avo/operand", "github.com/mmcloughlin/avo/reg"] 5 | http_status_code_whitelist = ["200", "400", "404", "500"] 6 | -------------------------------------------------------------------------------- /testutils/mockserver.go: -------------------------------------------------------------------------------- 1 | // Package testutils contains a mock execution backend (for testing and dev purposes) 2 | package testutils 3 | 4 | import ( 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "time" 10 | 11 | "go.uber.org/zap" 12 | ) 13 | 14 | var ( 15 | testLogger, _ = zap.NewDevelopment() 16 | testLog = testLogger.Sugar() 17 | ) 18 | 19 | type MockNodeBackend struct { 20 | LastRawRequest *http.Request 21 | LastJSONRPCRequest *JSONRPCRequest 22 | LastJSONRPCRequestTimestamp time.Time 23 | RPCHandlerOverride func(req *JSONRPCRequest) (result interface{}, err error) 24 | HTTPHandlerOverride func(w http.ResponseWriter, req *http.Request) 25 | } 26 | 27 | func NewMockNodeBackend() *MockNodeBackend { 28 | return &MockNodeBackend{} 29 | } 30 | 31 | func (be *MockNodeBackend) Reset() { 32 | be.LastRawRequest = nil 33 | be.LastJSONRPCRequest = nil 34 | be.LastJSONRPCRequestTimestamp = time.Time{} 35 | be.RPCHandlerOverride = nil 36 | be.HTTPHandlerOverride = nil 37 | } 
38 | 39 | func (be *MockNodeBackend) handleRPCRequest(req *JSONRPCRequest) (result interface{}, err error) { 40 | if be.RPCHandlerOverride != nil { 41 | return be.RPCHandlerOverride(req) 42 | } 43 | 44 | be.LastJSONRPCRequest = req 45 | 46 | switch req.Method { 47 | case "net_version": 48 | return "1", nil 49 | case "eth_callBundle": 50 | return "cool", nil 51 | } 52 | 53 | return "", fmt.Errorf("no RPC method handler implemented for %s", req.Method) 54 | } 55 | 56 | func (be *MockNodeBackend) Handler(w http.ResponseWriter, req *http.Request) { 57 | if be.HTTPHandlerOverride != nil { 58 | be.HTTPHandlerOverride(w, req) 59 | return 60 | } 61 | 62 | defer req.Body.Close() 63 | be.LastRawRequest = req 64 | be.LastJSONRPCRequestTimestamp = time.Now() 65 | 66 | testLog.Debugw("mockserver call", "remoteAddr", req.RemoteAddr, "method", req.Method, "url", req.URL) 67 | 68 | w.Header().Set("Content-Type", "application/json") 69 | testHeader := req.Header.Get("Test") 70 | w.Header().Set("Test", testHeader) 71 | 72 | returnError := func(id interface{}, msg string) { 73 | testLog.Debug("MockNodeBackend: returnError", "msg", msg) 74 | res := JSONRPCResponse{ 75 | ID: id, 76 | Error: &JSONRPCError{ 77 | Code: -32603, 78 | Message: msg, 79 | }, 80 | } 81 | 82 | if err := json.NewEncoder(w).Encode(res); err != nil { 83 | testLog.Debug("MockNodeBackend: error writing returnError response", "error", err, "response", res) 84 | } 85 | } 86 | 87 | body, err := io.ReadAll(req.Body) 88 | if err != nil { 89 | returnError(-1, fmt.Sprintf("failed to read request body: %v", err)) 90 | return 91 | } 92 | 93 | // Parse JSON RPC 94 | jsonReq := new(JSONRPCRequest) 95 | if err = json.Unmarshal(body, &jsonReq); err != nil { 96 | returnError(-1, fmt.Sprintf("failed to parse JSON RPC request: %v", err)) 97 | return 98 | } 99 | 100 | rawRes, err := be.handleRPCRequest(jsonReq) 101 | if err != nil { 102 | returnError(jsonReq.ID, err.Error()) 103 | return 104 | } 105 | 106 | 
w.WriteHeader(http.StatusOK) 107 | resBytes, err := json.Marshal(rawRes) 108 | if err != nil { 109 | fmt.Println("error mashalling rawRes:", rawRes, err) 110 | } 111 | 112 | res := NewJSONRPCResponse(jsonReq.ID, resBytes) 113 | 114 | // Write to client request 115 | if err := json.NewEncoder(w).Encode(res); err != nil { 116 | testLog.Error("error writing response", "error", err, "data", rawRes) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /testutils/mockserver_test.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestMockServer(t *testing.T) { 12 | mockNodeBackend := NewMockNodeBackend() 13 | mockNodeServer := httptest.NewServer(http.HandlerFunc(mockNodeBackend.Handler)) 14 | 15 | resp, err := http.PostForm(mockNodeServer.URL, nil) 16 | require.Nil(t, err, err) 17 | require.Equal(t, http.StatusOK, resp.StatusCode) 18 | 19 | mockNodeBackend.Reset() 20 | mockNodeBackend.HTTPHandlerOverride = func(w http.ResponseWriter, req *http.Request) { 21 | http.Error(w, "error", 479) 22 | } 23 | 24 | resp, err = http.PostForm(mockNodeServer.URL, nil) 25 | require.Nil(t, err, err) 26 | require.Equal(t, 479, resp.StatusCode) 27 | } 28 | -------------------------------------------------------------------------------- /testutils/types.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | ) 7 | 8 | type JSONRPCRequest struct { 9 | ID interface{} `json:"id"` 10 | Method string `json:"method"` 11 | Params []interface{} `json:"params"` 12 | Version string `json:"jsonrpc,omitempty"` 13 | } 14 | 15 | func NewJSONRPCRequest(id interface{}, method string, params []interface{}) *JSONRPCRequest { 16 | return &JSONRPCRequest{ 17 | ID: id, 18 | 
Method: method, 19 | Params: params, 20 | Version: "2.0", 21 | } 22 | } 23 | 24 | func NewJSONRPCRequest1(id interface{}, method string, param interface{}) *JSONRPCRequest { 25 | return NewJSONRPCRequest(id, method, []interface{}{param}) 26 | } 27 | 28 | type JSONRPCResponse struct { 29 | ID interface{} `json:"id"` 30 | Result json.RawMessage `json:"result"` 31 | Error *JSONRPCError `json:"error,omitempty"` 32 | Version string `json:"jsonrpc"` 33 | } 34 | 35 | func NewJSONRPCResponse(id interface{}, result json.RawMessage) *JSONRPCResponse { 36 | return &JSONRPCResponse{ 37 | ID: id, 38 | Result: result, 39 | Version: "2.0", 40 | } 41 | } 42 | 43 | // JSONRPCError as per the spec: https://www.jsonrpc.org/specification#error_object 44 | type JSONRPCError struct { 45 | Code int `json:"code"` 46 | Message string `json:"message"` 47 | } 48 | 49 | func (err JSONRPCError) Error() string { 50 | return fmt.Sprintf("Error %d (%s)", err.Code, err.Message) 51 | } 52 | --------------------------------------------------------------------------------