├── .env ├── .github ├── CODEOWNERS ├── mergeable.yml └── workflows │ ├── arm.yml │ └── ci.yml ├── .gitignore ├── CONTRIBUTORS.md ├── Dockerfile ├── Dockerfile-load-generator ├── Dockerfile-mysql ├── LICENSE ├── Makefile ├── README.md ├── bench_test.go ├── ctldb-example.sql ├── docker-compose-example.yml ├── docker-compose.yml ├── examples ├── basic │ └── basic.go └── simple_client.go ├── ext_reader_test.go ├── go.mod ├── go.sum ├── initialize.go ├── ldb.go ├── ldb_reader.go ├── ldb_reader_test.go ├── ldb_rotating_reader.go ├── ldb_rotating_reader_test.go ├── ldb_testing.go ├── ldb_testing_test.go ├── marshal_test.go ├── pkg ├── changelog │ ├── changelog_writer.go │ └── changelog_writer_test.go ├── cmd │ ├── ctlstore-cli │ │ ├── cmd │ │ │ ├── add_fields.go │ │ │ ├── create_family.go │ │ │ ├── create_table.go │ │ │ ├── flags.go │ │ │ ├── read_keys.go │ │ │ ├── read_keys_test.go │ │ │ ├── read_seq.go │ │ │ ├── root.go │ │ │ ├── table_limits.go │ │ │ ├── table_limits_test.go │ │ │ ├── utils.go │ │ │ └── writer_limits.go │ │ └── main.go │ ├── ctlstore-mutator │ │ └── main.go │ └── ctlstore │ │ └── main.go ├── ctldb │ ├── ctldb.go │ ├── ctldb_test_helpers.go │ └── dsn_parameters.go ├── errs │ └── errs.go ├── event │ ├── changelog.go │ ├── changelog_test.go │ ├── entry.go │ ├── event.go │ ├── fake_changelog.go │ ├── fake_log_writer.go │ ├── iterator.go │ ├── iterator_integration_test.go │ └── iterator_test.go ├── executive │ ├── db_executive.go │ ├── db_executive_test.go │ ├── db_info.go │ ├── db_limiter.go │ ├── db_limiter_test.go │ ├── dml_ledger_writer.go │ ├── dml_ledger_writer_test.go │ ├── executive.go │ ├── executive_endpoint.go │ ├── executive_endpoint_test.go │ ├── executive_service.go │ ├── fake_time.go │ ├── fakes │ │ └── executive_interface.go │ ├── generate.go │ ├── health.go │ ├── mutators_store.go │ ├── mutators_store_test.go │ ├── sql.go │ ├── sql_test.go │ ├── status_writer.go │ ├── table_sizer.go │ ├── table_sizer_test.go │ ├── test_executive.go │ └── 
test_executive_test.go ├── fakes │ └── s3_client.go ├── globalstats │ ├── stats.go │ └── stats_test.go ├── heartbeat │ ├── heartbeat.go │ └── heartbeat_test.go ├── ldb │ └── ldbs.go ├── ldbwriter │ ├── changelog_callback.go │ ├── ldb_callback_writer.go │ ├── ldb_writer.go │ ├── ldb_writer_test.go │ └── ldb_writer_with_changelog.go ├── ledger │ ├── ecs_client.go │ ├── ecs_metadata.go │ ├── fake_ticker.go │ ├── fake_ticker_test.go │ ├── fakes │ │ └── ecs_client.go │ ├── generate.go │ ├── ledger_monitor.go │ ├── ledger_monitor_test.go │ └── opts.go ├── limits │ ├── limits.go │ └── limits_test.go ├── logwriter │ ├── sized_log_writer.go │ └── sized_log_writer_test.go ├── mysql │ └── mysql_info.go ├── reflector │ ├── bootstrap_test.go │ ├── dml_source.go │ ├── dml_source_test.go │ ├── download.go │ ├── download_test.go │ ├── fakes │ │ └── fake_reflector.go │ ├── generate.go │ ├── jitter.go │ ├── pipeline_test.go │ ├── reflector.go │ ├── reflector_ctl.go │ ├── reflector_ctl_test.go │ ├── reflector_test.go │ ├── s3_client.go │ ├── shovel.go │ ├── shovel_test.go │ ├── wal_monitor.go │ └── wal_monitor_test.go ├── scanfunc │ ├── marshal.go │ ├── noop_scanner.go │ └── scan_func.go ├── schema │ ├── db_column_info.go │ ├── db_column_meta.go │ ├── dml.go │ ├── family_name.go │ ├── family_table.go │ ├── field_name.go │ ├── field_type.go │ ├── ldb_table_name.go │ ├── named_field_type.go │ ├── params.go │ ├── primary_key.go │ ├── table.go │ ├── table_name.go │ ├── validate_test.go │ └── writer_name.go ├── sidecar │ ├── sidecar.go │ └── sidecar_test.go ├── sqlgen │ ├── sqlgen.go │ └── sqlgen_test.go ├── sqlite │ ├── driver.go │ ├── sql_change_buffer.go │ ├── sql_change_buffer_test.go │ ├── sqlite_info.go │ ├── sqlite_watch.go │ └── sqlite_watch_test.go ├── supervisor │ ├── archived_snapshot.go │ ├── fake_read_closer.go │ ├── fakes │ │ └── s3_uploader.go │ ├── generate.go │ ├── gzip_pipe.go │ ├── gzip_pipe_test.go │ ├── s3_snapshot_test.go │ ├── s3_uploader.go │ ├── supervisor.go │ 
└── supervisor_test.go ├── tests │ ├── tests.go │ └── tests_test.go ├── units │ └── units.go ├── unsafe │ ├── unsafe.go │ └── unsafe_test.go ├── utils │ ├── atomic_bool.go │ ├── doc.go │ ├── ensure_dir.go │ ├── interface_slice.go │ ├── json_reader.go │ ├── looper.go │ ├── teardowns.go │ └── teardowns_test.go └── version │ ├── version.go │ └── version_go1_12.go ├── rows.go ├── scripts └── download.sh ├── tools.go └── version.go /.env: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/segmentio/ctlstore/9bb6b760151d7ae56f3d2217bc9b029f44f91dfd/.env -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @segmentio/team-cloud-config 2 | -------------------------------------------------------------------------------- /.github/mergeable.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | mergeable: 3 | - when: pull_request.*, pull_request_review.* 4 | name: Change Control Pre-Merge-Check 5 | validate: 6 | - do: or 7 | validate: 8 | - do: and 9 | validate: 10 | - do: approvals 11 | min: 12 | count: 1 13 | - do: description 14 | or: 15 | - and: 16 | - must_exclude: 17 | regex: Testing completed successfully 18 | - must_include: 19 | regex: Testing not required 20 | - and: 21 | - must_include: 22 | regex: Testing completed successfully 23 | - must_exclude: 24 | regex: Testing not required 25 | - must_include: 26 | regex: 'CC-\d{4,5}' 27 | - do: title 28 | must_include: 29 | regex: stage|staging|README|non-prod|docs 30 | pass: 31 | - do: checks 32 | status: success 33 | payload: 34 | title: Mergeable Run has been Completed! 35 | summary: All the validators are passing! 36 | fail: 37 | - do: checks 38 | status: failure 39 | payload: 40 | title: Mergeable Run has been Completed! 
41 | summary: "### Status: {{toUpperCase validationStatus}}\ 42 | \nHere are some stats of the run:\ 43 | \n{{#with validationSuites.[0]}} {{ validations.length }} validations were ran. {{/with}}\n" 44 | text: "{{#each validationSuites}}\n 45 | ### {{{statusIcon status}}} Change-Control Pre-Merge Check \n 46 | #### All PRs must follow bellow Change-Control rules: \n 47 | * ##### {{#with validations.[0]}} {{{statusIcon status}}} Must have at least one approval.\n {{/with}} 48 | * ##### {{#with validations.[1]}} {{{statusIcon status}}} Description includes a testing plan: \n 49 | \t ##### \"Testing not required\" OR \"Testing completed successfully\" but NOT BOTH. \n 50 | \t ##### OR \n 51 | \t ##### Jira Change-Control ticket is included.\n {{/with}}\n\n 52 | #### PRs that are exempt from Change-Control: \n 53 | * ##### {{#with validations.[2]}} {{{statusIcon status}}} Title includes stage, staging, README, non-prod, docs.\n {{/with}}\n 54 | {{/each}}" 55 | -------------------------------------------------------------------------------- /.github/workflows/arm.yml: -------------------------------------------------------------------------------- 1 | name: arm-build 2 | on: 3 | push: 4 | branches: ["master"] 5 | pull_request: 6 | branches: ["**"] 7 | jobs: 8 | 9 | publish-arm-production: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | contents: read 13 | packages: write 14 | if: ${{ (github.ref_name == 'master') && (github.event_name == 'push')}} 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: setup env variables 18 | id: vars 19 | run: | 20 | SHA=${GITHUB_SHA:0:7} 21 | echo "SHA=$SHA" >> $GITHUB_ENV 22 | echo "IMAGE=ghcr.io/segmentio/ctlstore:$SHA-arm" >> $GITHUB_ENV 23 | 24 | - name: "Image Name" 25 | run: echo "publishing ${IMAGE}" 26 | 27 | - name: Log in to the Container registry 28 | uses: docker/login-action@v2 29 | with: 30 | registry: ghcr.io 31 | username: ${{ github.actor }} 32 | password: ${{ secrets.GITHUB_TOKEN }} 33 | logout: true 34 | 35 | 
- name: Build and push image for master 36 | run: | 37 | docker context create buildx-build 38 | docker buildx create --use buildx-build 39 | docker buildx build \ 40 | --platform=linux/arm64 \ 41 | -t ${IMAGE} \ 42 | --build-arg VERSION=${SHA} \ 43 | --push \ 44 | . 45 | 46 | publish-arm-pr: 47 | runs-on: ubuntu-latest 48 | permissions: 49 | contents: read 50 | packages: write 51 | if: ${{ (github.event_name == 'pull_request') }} 52 | steps: 53 | - uses: actions/checkout@v3 54 | - name: setup env variables 55 | id: vars 56 | run: | 57 | SHA=$(git rev-parse --short ${{ github.event.pull_request.head.sha }}) 58 | echo "SHA=$SHA" >> $GITHUB_ENV 59 | echo "IMAGE=ghcr.io/segmentio/ctlstore:$(echo ${GITHUB_HEAD_REF:0:116} | sed 's/[^a-zA-Z0-9]/-/g' )-$SHA-arm" >> $GITHUB_ENV 60 | 61 | - name: "Image Name" 62 | run: echo "publishing ${IMAGE}" 63 | 64 | - name: Log in to the Container registry 65 | uses: docker/login-action@v2 66 | with: 67 | registry: ghcr.io 68 | username: ${{ github.actor }} 69 | password: ${{ secrets.GITHUB_TOKEN }} 70 | logout: true 71 | 72 | - name: Build and push image for pull request 73 | run: | 74 | docker context create buildx-build 75 | docker buildx create --use buildx-build 76 | docker buildx build \ 77 | --platform=linux/arm64 \ 78 | -t ${IMAGE} \ 79 | --build-arg VERSION=${SHA} \ 80 | --push \ 81 | . 
82 | - run: echo "GHCR PUBLISH SUCCESSFUL" 83 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: ["master"] 5 | pull_request: 6 | branches: ["**"] 7 | jobs: 8 | build: 9 | name: Executes a full build 10 | runs-on: ubuntu-latest 11 | services: 12 | mysql: 13 | image: mysql:5.6 14 | env: 15 | MYSQL_ROOT_PASSWORD: ctldbpw 16 | MYSQL_DATABASE: ctldb 17 | MYSQL_USER: ctldb 18 | MYSQL_PASSWORD: ctldbpw 19 | ports: 20 | - 3306:3306 21 | 22 | steps: 23 | - name: checkout 24 | uses: actions/checkout@v3 25 | 26 | - name: setup go 1.20 27 | uses: actions/setup-go@v3 28 | with: 29 | go-version: '1.20' 30 | 31 | - name: Deps 32 | run: | 33 | make deps 34 | 35 | - name: Test 36 | run: | 37 | make test 38 | 39 | - name: build 40 | run: | 41 | make build 42 | 43 | publish-amd-production: 44 | needs: [ build ] 45 | runs-on: ubuntu-latest 46 | permissions: 47 | contents: read 48 | packages: write 49 | if: ${{ (github.ref_name == 'master') && (github.event_name == 'push')}} 50 | steps: 51 | - uses: actions/checkout@v3 52 | - name: setup env variables 53 | id: vars 54 | run: | 55 | SHA=${GITHUB_SHA:0:7} 56 | echo "SHA=$SHA" >> $GITHUB_ENV 57 | echo "IMAGE=ghcr.io/segmentio/ctlstore:$SHA" >> $GITHUB_ENV 58 | 59 | - name: "Image Name" 60 | run: echo "publishing ${IMAGE}" 61 | 62 | - name: Log in to the Container registry 63 | uses: docker/login-action@v2 64 | with: 65 | registry: ghcr.io 66 | username: ${{ github.actor }} 67 | password: ${{ secrets.GITHUB_TOKEN }} 68 | logout: true 69 | 70 | - name: Build and push image for master 71 | run: | 72 | docker context create buildx-build 73 | docker buildx create --use buildx-build 74 | docker buildx build \ 75 | --platform=linux/amd64 \ 76 | -t ${IMAGE} \ 77 | --build-arg VERSION=${SHA} \ 78 | --push \ 79 | . 
80 | 81 | publish-amd-pr: 82 | needs: [ build ] 83 | runs-on: ubuntu-latest 84 | permissions: 85 | contents: read 86 | packages: write 87 | if: ${{ (github.event_name == 'pull_request') }} 88 | steps: 89 | - uses: actions/checkout@v3 90 | - name: setup env variables 91 | id: vars 92 | run: | 93 | SHA=$(git rev-parse --short ${{ github.event.pull_request.head.sha }}) 94 | echo "SHA=$SHA" >> $GITHUB_ENV 95 | echo "IMAGE=ghcr.io/segmentio/ctlstore:$(echo ${GITHUB_HEAD_REF:0:119} | sed 's/[^a-zA-Z0-9]/-/g' )-$SHA" >> $GITHUB_ENV 96 | 97 | - name: "Image Name" 98 | run: echo "publishing ${IMAGE}" 99 | 100 | - name: Log in to the Container registry 101 | uses: docker/login-action@v2 102 | with: 103 | registry: ghcr.io 104 | username: ${{ github.actor }} 105 | password: ${{ secrets.GITHUB_TOKEN }} 106 | logout: true 107 | 108 | - name: Build and push image for pull request 109 | run: | 110 | docker context create buildx-build 111 | docker buildx create --use buildx-build 112 | docker buildx build \ 113 | --platform=linux/amd64 \ 114 | -t ${IMAGE} \ 115 | --build-arg VERSION=${SHA} \ 116 | --push \ 117 | . 118 | - run: echo "GHCR PUBLISH SUCCESSFUL" 119 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | /ctlstore 3 | /vendor 4 | /bin 5 | .coverprofile 6 | .vscode 7 | .idea -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | # Contributors 2 | 3 | ## How can I help? 4 | 5 | First, thank you for contributing! 6 | 7 | For small things like typos, or trivial bug fixes, please feel free to submit a PR. 8 | Code changes should be accompanied by tests where it makes sense. 9 | 10 | For anything larger than that, please file an issue first describing the change 11 | you'd like to make so it can be discussed. 
We want to make sure your time is 12 | respected. If for whatever reason we decide a change can't be merged in, we'd 13 | prefer that is decided before you put a lot of effort into the change. 14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.20-alpine 2 | ENV SRC github.com/segmentio/ctlstore 3 | ARG VERSION 4 | 5 | RUN apk --update add gcc git curl alpine-sdk libc6-compat ca-certificates sqlite \ 6 | && curl -SsL https://github.com/segmentio/chamber/releases/download/v2.13.2/chamber-v2.13.2-linux-amd64 -o /bin/chamber \ 7 | && curl -sL https://github.com/peak/s5cmd/releases/download/v2.1.0/s5cmd_2.1.0_Linux-64bit.tar.gz -o s5cmd.gz && tar -xzf s5cmd.gz -C /bin \ 8 | && chmod +x /bin/chamber \ 9 | && chmod +x /bin/s5cmd 10 | 11 | 12 | COPY . /go/src/${SRC} 13 | WORKDIR /go/src/${SRC} 14 | RUN go mod vendor 15 | RUN CGO_ENABLED=1 go install -ldflags="-X github.com/segmentio/ctlstore/pkg/version.version=$VERSION" ${SRC}/pkg/cmd/ctlstore \ 16 | && cp ${GOPATH}/bin/ctlstore /usr/local/bin 17 | 18 | RUN CGO_ENABLED=1 go install -ldflags="-X github.com/segmentio/ctlstore/pkg/version.version=$VERSION" ${SRC}/pkg/cmd/ctlstore-cli \ 19 | && cp ${GOPATH}/bin/ctlstore-cli /usr/local/bin 20 | 21 | FROM alpine 22 | RUN apk --no-cache add sqlite pigz aws-cli perl-utils jq 23 | 24 | COPY --from=0 /go/src/github.com/segmentio/ctlstore/scripts/download.sh . 25 | COPY --from=0 /bin/chamber /bin/chamber 26 | COPY --from=0 /bin/s5cmd /bin/s5cmd 27 | COPY --from=0 /usr/local/bin/ctlstore /usr/local/bin/ 28 | COPY --from=0 /usr/local/bin/ctlstore-cli /usr/local/bin/ 29 | -------------------------------------------------------------------------------- /Dockerfile-load-generator: -------------------------------------------------------------------------------- 1 | FROM golang:1.15 2 | COPY . 
/go/src/github.com/segmentio/ctlstore/ 3 | RUN go install github.com/segmentio/ctlstore/pkg/cmd/ctlstore-mutator 4 | ENTRYPOINT /go/bin/ctlstore-mutator 5 | -------------------------------------------------------------------------------- /Dockerfile-mysql: -------------------------------------------------------------------------------- 1 | FROM mysql:5.6 2 | COPY ctldb-example.sql /docker-entrypoint-initdb.d/ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Segment.io, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | VERSION := $(shell git describe --tags --always --dirty="-dev") 2 | LDFLAGS := -ldflags='-X "github.com/segmentio/ctlstore/pkg/version.version=$(VERSION)"' 3 | DOCKER_REPO := 528451384384.dkr.ecr.us-west-2.amazonaws.com/ctlstore 4 | Q= 5 | 6 | GOTESTFLAGS = -race -count 1 7 | 8 | export GO111MODULE?=on 9 | 10 | .PHONY: deps 11 | deps: 12 | $Qgo get -d ./... 13 | 14 | .PHONY: vendor 15 | vendor: 16 | $Qgo mod vendor 17 | 18 | .PHONY: clean 19 | clean: 20 | $Qrm -rf vendor/ && git checkout ./vendor && dep ensure 21 | 22 | .PHONY: install 23 | install: 24 | $Qgo install ./pkg/cmd/ctlstore 25 | 26 | .PHONY: build 27 | build: deps 28 | $Qgo build -ldflags="-X github.com/segmentio/ctlstore/pkg/version.version=${VERSION} -X github.com/segmentio/ctlstore/pkg/globalstats.version=${VERSION}" -o ./bin/ctlstore ./pkg/cmd/ctlstore 29 | 30 | .PHONY: docker 31 | docker: 32 | $Qdocker build --build-arg VERSION=$(VERSION) \ 33 | -t $(DOCKER_REPO):$(VERSION) \ 34 | . 35 | 36 | .PHONY: releasecheck 37 | releasecheck: 38 | $Qexit $(shell git status --short | wc -l) 39 | 40 | .PHONY: release-nonmaster 41 | release-nonmaster: docker 42 | $Qdocker push $(DOCKER_REPO):$(VERSION) 43 | 44 | .PHONY: release 45 | release: docker 46 | $Qdocker tag $(DOCKER_REPO):$(VERSION) $(DOCKER_REPO):latest 47 | $Qdocker push $(DOCKER_REPO):$(VERSION) 48 | $Qdocker push $(DOCKER_REPO):latest 49 | 50 | .PHONY: release-stable 51 | release-stable: docker 52 | $Qdocker tag $(DOCKER_REPO):$(VERSION) $(DOCKER_REPO):stable 53 | $Qdocker push $(DOCKER_REPO):stable 54 | 55 | .PHONY: vet 56 | vet: 57 | $Qgo vet ./... 58 | 59 | .PHONY: generate 60 | generate: 61 | $Qgo generate ./... 62 | 63 | .PHONY: fmtcheck 64 | fmtchk: 65 | $Qexit $(shell gofmt -l . 
| grep -v '^vendor' | wc -l) 66 | 67 | .PHONY: fmtfix 68 | fmtfix: 69 | $Qgofmt -w $(shell find . -iname '*.go' | grep -v vendor) 70 | 71 | .PHONY: test 72 | test: 73 | $Qgo test $(GOTESTFLAGS) ./... 74 | 75 | .PHONY: bench 76 | bench: 77 | $Qgo test $(GOTESTFLAGS) -bench . 78 | 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ctlstore 2 | 3 | ctlstore is a distributed data store that provides very low latency, 4 | always-available, "infinitely" scalable reads. The underlying mechanism for 5 | this is a SQLite database that runs on every host called the LDB. A daemon 6 | called the Reflector plays logged writes from a central database into the LDB. 7 | As this involves replicating the full data store on every host, it is only 8 | practical for situations where the write rate (<100/s total) and data volumes 9 | (<10GB total) are low. 10 | 11 | Recommended reading: 12 | 13 | * [Rick Branson's blog post on ctlstore](https://segment.com/blog/separating-our-data-and-control-planes-with-ctlstore/) 14 | * [Calvin talking about ctlstore at Synapse](https://vimeo.com/293246627) 15 | 16 | ## Security 17 | 18 | Note that because ctlstore replicates the central database to an LDB on each host 19 | with a reflector, that LDB contains all of the control data. In its current state 20 | that means that any application which has access to the LDB can access all of the 21 | data within it. 22 | 23 | The implications of this are that you should not store data in ctlstore that should 24 | only be accessed by a subset of the applications that can read the LDB. Things like 25 | secrets, passwords, and so on, are an example of this. 26 | 27 | The ctlstore system is meant to store non-sensitive configuration data. 
28 | 29 | ## Development 30 | 31 | A MySQL database is needed to run the tests, which can be started using Docker Compose: 32 | 33 | ``` 34 | $ docker-compose up -d 35 | ``` 36 | 37 | Run the tests using make: 38 | 39 | ``` 40 | $ make test 41 | # For more verbosity (`Q=` trick applies to all targets) 42 | $ make test Q= 43 | ``` 44 | 45 | A single `ctlstore` binary is used for all functionality. Build it with make: 46 | 47 | ``` 48 | $ make build 49 | ``` 50 | 51 | Sync non-stdlib dependencies and pull them into `./vendor` 52 | 53 | ``` 54 | $ make deps 55 | ``` 56 | 57 | Ctlstore uses Go modules. To build a docker image, the dependencies must be vendored 58 | first: 59 | 60 | ``` 61 | $ make vendor 62 | ``` 63 | 64 | Many of ctlstore's unit tests use mocks. To regenerate the mocks using [counterfeiter](https://github.com/maxbrunsfeld/counterfeiter): 65 | 66 | ``` 67 | $ make generate 68 | ``` 69 | 70 | ## Tying the Pieces Together 71 | 72 | This project includes a docker-compose file `docker-compose-example.yml`. 
This initializes and runs 73 | 74 | * mysql (ctlstore SoR) 75 | * executive service (guards the ctlstore SoR) 76 | * reflector (builds the LDB) 77 | * heartbeat (mutates a ctlstore table periodically) 78 | * sidecar (provides HTTP API access to ctlstore reader API) 79 | * supervisor (periodically snapshots LDB) 80 | 81 | To start it, run: 82 | 83 | ``` 84 | $ make deps 85 | $ make vendor 86 | $ docker-compose -f docker-compose-example.yml up -d 87 | ``` 88 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | package ctlstore 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "io/ioutil" 7 | "os" 8 | "path/filepath" 9 | "testing" 10 | 11 | "github.com/segmentio/ctlstore/pkg/ldb" 12 | "github.com/segmentio/ctlstore/pkg/tests" 13 | ) 14 | 15 | func BenchmarkLDBQueryBaseline(b *testing.B) { 16 | b.StopTimer() 17 | 18 | ctx := context.TODO() 19 | 20 | tmpDir, err := ioutil.TempDir("", "") 21 | if err != nil { 22 | b.Fatalf("Unexpected error creating temp dir: %v", err) 23 | } 24 | defer os.RemoveAll(tmpDir) 25 | 26 | ldbPath := filepath.Join(tmpDir, "tmp.db") 27 | localDB, err := sql.Open("sqlite3", ldbPath) 28 | if err != nil { 29 | b.Fatalf("Unexpected error opening LDB: %v", err) 30 | } 31 | defer localDB.Close() 32 | 33 | err = ldb.EnsureLdbInitialized(ctx, localDB) 34 | if err != nil { 35 | b.Fatalf("Unexpected error initializing LDB: %v", err) 36 | } 37 | 38 | _, err = localDB.ExecContext(ctx, ` 39 | CREATE TABLE foo___bar ( 40 | key VARCHAR PRIMARY KEY, 41 | val VARCHAR 42 | ); 43 | INSERT INTO foo___bar VALUES('foo', 'bar'); 44 | `) 45 | if err != nil { 46 | b.Fatalf("Unexpected error inserting value into LDB: %v", err) 47 | } 48 | 49 | prepQ, err := localDB.PrepareContext(ctx, "SELECT * FROM foo___bar WHERE key = ?") 50 | if err != nil { 51 | b.Fatalf("Unexpected error preparing query: %v", err) 52 | } 53 | 54 | b.StartTimer() 55 
| 56 | for i := 0; i < b.N; i++ { 57 | rows, err := prepQ.QueryContext(ctx, "foo") 58 | if err != nil { 59 | b.Errorf("Error querying: %v", err) 60 | continue 61 | } 62 | for rows.Next() { 63 | var keyReceiver, valReceiver string 64 | err = rows.Scan(&keyReceiver, &valReceiver) 65 | if err != nil { 66 | b.Errorf("Error scanning: %v", err) 67 | continue 68 | } 69 | if keyReceiver != "foo" { 70 | b.Errorf("Received unexpected key: %v", keyReceiver) 71 | } 72 | if valReceiver != "bar" { 73 | b.Errorf("Received unexpected val: %v", valReceiver) 74 | } 75 | } 76 | } 77 | } 78 | 79 | func BenchmarkGetRowByKey(b *testing.B) { 80 | ctx := context.TODO() 81 | 82 | type benchContext struct { 83 | ldb *sql.DB 84 | ctx context.Context 85 | r *LDBReader 86 | } 87 | 88 | type benchKVRow struct { 89 | Key string `ctlstore:"key"` 90 | Value string `ctlstore:"val"` 91 | } 92 | 93 | testTmpDir, teardown := tests.WithTmpDir(b) 94 | defer teardown() 95 | 96 | ldbPath := filepath.Join(testTmpDir, "tmp.db") 97 | localDB, err := sql.Open("sqlite3", ldbPath) 98 | if err != nil { 99 | b.Fatalf("Unexpected error opening LDB: %v", err) 100 | } 101 | 102 | err = ldb.EnsureLdbInitialized(ctx, localDB) 103 | if err != nil { 104 | b.Fatalf("Unexpected error initializing LDB: %v", err) 105 | } 106 | 107 | _, err = localDB.ExecContext(ctx, ` 108 | CREATE TABLE foo___bar ( 109 | key VARCHAR PRIMARY KEY, 110 | val VARCHAR 111 | ); 112 | INSERT INTO foo___bar VALUES('foo', 'bar'); 113 | `) 114 | if err != nil { 115 | b.Fatalf("Unexpected error inserting value into LDB: %v", err) 116 | } 117 | 118 | r := NewLDBReaderFromDB(localDB) 119 | 120 | benchSetup := &benchContext{ 121 | ldb: localDB, 122 | ctx: ctx, 123 | r: r, 124 | } 125 | 126 | for i := 0; i < b.N; i++ { 127 | var row benchKVRow 128 | found, err := benchSetup.r.GetRowByKey(benchSetup.ctx, &row, "foo", "bar", "foo") 129 | if err != nil { 130 | b.Fatalf("Unexpected error calling GetRowByKey: %v", err) 131 | } 132 | if !found { 133 | 
b.Fatal("Should have found a row") 134 | } 135 | if row.Key != "foo" { 136 | b.Fatalf("Unexpected value in row key: %v", row.Key) 137 | } 138 | if row.Value != "bar" { 139 | b.Fatalf("Unexpected value in row val: %v", row.Value) 140 | } 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /ctldb-example.sql: -------------------------------------------------------------------------------- 1 | USE ctldb; 2 | 3 | ALTER DATABASE CHARACTER SET = utf8mb4 COLLATE = utf8mb4_unicode_ci; 4 | 5 | DROP TABLE IF EXISTS families; 6 | CREATE TABLE families ( 7 | id INTEGER AUTO_INCREMENT PRIMARY KEY, 8 | name VARCHAR(191) NOT NULL, 9 | UNIQUE KEY name (name) 10 | ); 11 | 12 | DROP TABLE IF EXISTS mutators; 13 | CREATE TABLE mutators ( 14 | writer VARCHAR(191) NOT NULL PRIMARY KEY, 15 | secret VARCHAR(255) NOT NULL, 16 | cookie BLOB(1024) NOT NULL, 17 | clock BIGINT NOT NULL DEFAULT 0 18 | ); 19 | 20 | DROP TABLE IF EXISTS ctlstore_dml_ledger; 21 | CREATE TABLE ctlstore_dml_ledger ( 22 | seq INTEGER AUTO_INCREMENT PRIMARY KEY, 23 | leader_ts DATETIME DEFAULT CURRENT_TIMESTAMP, 24 | statement MEDIUMTEXT NOT NULL 25 | ); 26 | 27 | DROP TABLE IF EXISTS locks; 28 | CREATE TABLE locks ( 29 | id VARCHAR(191) NOT NULL PRIMARY KEY, 30 | clock BIGINT NOT NULL DEFAULT 0 31 | ); 32 | 33 | INSERT INTO locks VALUES('ledger', 0); 34 | 35 | DROP TABLE IF EXISTS max_table_sizes; 36 | CREATE TABLE max_table_sizes ( 37 | family_name VARCHAR(30) NOT NULL, /* limit pulled from validate.go */ 38 | table_name VARCHAR(50) NOT NULL, /* limit pulled from validate.go */ 39 | warn_size_bytes BIGINT NOT NULL DEFAULT 0, 40 | max_size_bytes BIGINT NOT NULL DEFAULT 0, 41 | PRIMARY KEY (family_name, table_name) 42 | ); 43 | 44 | DROP TABLE IF EXISTS max_writer_rates; 45 | CREATE TABLE max_writer_rates ( 46 | writer_name VARCHAR(50) NOT NULL, /* limit pulled from validate.go */ 47 | max_rows_per_minute BIGINT NOT NULL , 48 | PRIMARY KEY (writer_name) 49 | ); 50 | 
51 | DROP TABLE IF EXISTS writer_usage; 52 | CREATE TABLE writer_usage ( 53 | writer_name VARCHAR(50) NOT NULL, /* limit pulled from validate.go */ 54 | bucket BIGINT NOT NULL, 55 | amount BIGINT NOT NULL , 56 | PRIMARY KEY (writer_name, bucket) 57 | ); 58 | 59 | -------------------------------------------------------------------------------- /docker-compose-example.yml: -------------------------------------------------------------------------------- 1 | version: '2.2' 2 | volumes: 3 | ldb: 4 | services: 5 | 6 | # sidecar exposes the reader API over HTTP 7 | sidecar: 8 | build: 9 | context: . 10 | dockerfile: Dockerfile 11 | ports: 12 | - "1331" 13 | restart: always 14 | entrypoint: 15 | - /usr/local/bin/ctlstore 16 | - sidecar 17 | volumes: 18 | - ldb:/var/spool/ctlstore 19 | 20 | # heartbeat sends a constant stream of mutations into the executive 21 | heartbeat: 22 | restart: always 23 | build: 24 | context: . 25 | dockerfile: Dockerfile 26 | entrypoint: 27 | - /usr/local/bin/ctlstore 28 | - heartbeat 29 | - -executive-url 30 | - executive:3000 31 | - -heartbeat-interval 32 | - 1s 33 | - -family-name 34 | - ctlstore 35 | - -table-name 36 | - heartbeats 37 | - -writer-name 38 | - heartbeat 39 | - -writer-secret 40 | - heartbeat 41 | 42 | # supervisor periodically snapshots ldb 43 | supervisor: 44 | restart: always 45 | build: 46 | context: . 47 | dockerfile: Dockerfile 48 | entrypoint: 49 | - /usr/local/bin/ctlstore 50 | - supervisor 51 | - -snapshot-url 52 | - "file:///snapshots/snapshot.db" 53 | - -snapshot-interval 54 | - "60s" 55 | - -reflector.ldb-path 56 | - /data/supervisor-ldb.db 57 | - -reflector.upstream-driver 58 | - mysql 59 | - -reflector.upstream-dsn 60 | - ctldb:ctldbpw@tcp(mysql:3306)/ctldb?collation=utf8mb4_unicode_ci 61 | volumes: 62 | - ldb:/var/spool/ctlstore 63 | 64 | # reflector pulls changes from mysql to a ldb 65 | reflector: 66 | restart: always 67 | build: 68 | context: . 
69 | dockerfile: Dockerfile 70 | ports: 71 | - 9090:9090 72 | entrypoint: 73 | - /usr/local/bin/ctlstore 74 | - reflector 75 | - -ldb-path 76 | - /var/spool/ctlstore/ldb.db 77 | - -changelog-path 78 | - /var/spool/ctlstore/changelog 79 | - -changelog-size 80 | - "1000000" 81 | - -upstream-driver 82 | - mysql 83 | - -upstream-dsn 84 | - ctldb:ctldbpw@tcp(mysql:3306)/ctldb?collation=utf8mb4_unicode_ci 85 | - -ledger-latency.disable # no ECS in a docker-compose environment 86 | - -metrics-bind 87 | - 0.0.0.0:9090 88 | volumes: 89 | - ldb:/var/spool/ctlstore 90 | 91 | # executive guards the mysqldb 92 | executive: 93 | restart: always 94 | build: 95 | context: . 96 | dockerfile: Dockerfile 97 | entrypoint: 98 | - /usr/local/bin/ctlstore 99 | - executive 100 | - -bind 101 | - 0.0.0.0:3000 102 | - -ctldb 103 | - ctldb:ctldbpw@tcp(mysql:3306)/ctldb?collation=utf8mb4_unicode_ci 104 | 105 | # mysql represents the upstream db 106 | mysql: 107 | build: 108 | context: . 109 | dockerfile: Dockerfile-mysql 110 | restart: 111 | always 112 | environment: 113 | MYSQL_ROOT_PASSWORD: ctldbpw 114 | MYSQL_DATABASE: ctldb 115 | MYSQL_USER: ctldb 116 | MYSQL_PASSWORD: ctldbpw 117 | mem_limit: 536870912 118 | 119 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2.2' 2 | services: 3 | mysql: 4 | image: mysql:5.6 5 | platform: linux/amd64 6 | restart: 7 | always 8 | ports: 9 | - "3306:3306" 10 | environment: 11 | MYSQL_ROOT_PASSWORD: ctldbpw 12 | MYSQL_DATABASE: ctldb 13 | MYSQL_USER: ctldb 14 | MYSQL_PASSWORD: ctldbpw 15 | mem_limit: 536870912 16 | -------------------------------------------------------------------------------- /examples/simple_client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "reflect" 8 | "time" 9 | 10 | 
"github.com/segmentio/ctlstore" 11 | ) 12 | 13 | type MyTableRow struct { 14 | Key string `ctlstore:"mykey"` 15 | Val string `ctlstore:"myval"` 16 | } 17 | 18 | func main() { 19 | ctx := context.TODO() 20 | doneCh := make(chan struct{}) 21 | ldbReader, err := ctlstore.Reader() 22 | if err != nil { 23 | fmt.Printf("error: %+v\n", err) 24 | os.Exit(1) 25 | } 26 | 27 | go func() { 28 | defer close(doneCh) 29 | 30 | var lastVal MyTableRow 31 | var foundState bool 32 | var lastErr error 33 | 34 | for { 35 | var val MyTableRow 36 | found, err := ldbReader.GetRowByKey(ctx, &val, "myfamily", "mytable", "hello") 37 | if err != nil { 38 | if lastErr != err { 39 | fmt.Printf("New error reading LDB: %v\n", err) 40 | } 41 | } else { 42 | if found { 43 | if !reflect.DeepEqual(val, lastVal) { 44 | fmt.Printf("Got new data for key: %v\n", val) 45 | } 46 | foundState = false 47 | } 48 | 49 | if !found && !foundState { 50 | fmt.Println("Didn't find key, will let you know when I do!") 51 | foundState = true 52 | } 53 | 54 | lastVal = val 55 | } 56 | lastErr = err 57 | time.Sleep(1 * time.Second) 58 | } 59 | }() 60 | 61 | <-doneCh 62 | } 63 | -------------------------------------------------------------------------------- /ext_reader_test.go: -------------------------------------------------------------------------------- 1 | package ctlstore_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "testing" 7 | 8 | "github.com/google/go-cmp/cmp" 9 | "github.com/segmentio/ctlstore" 10 | ) 11 | 12 | func TestGetRowByKeyExternalPackage(t *testing.T) { 13 | type testKVStruct struct { 14 | Key string `ctlstore:"key"` 15 | Val string `ctlstore:"value"` 16 | } 17 | 18 | const initSQL = ` 19 | CREATE TABLE family1___table1 ( 20 | key VARCHAR PRIMARY KEY, 21 | value VARCHAR 22 | ); 23 | 24 | INSERT INTO family1___table1 VALUES('foo', 'bar'); 25 | ` 26 | ctx := context.Background() 27 | db, err := sql.Open("sqlite3", ":memory:") 28 | if err != nil { 29 | t.Fatalf("Unexpected error %+v", err) 
30 | } 31 | _, err = db.Exec(initSQL) 32 | if err != nil { 33 | t.Fatalf("Unexpected error +%v", err) 34 | } 35 | 36 | reader := ctlstore.NewLDBReaderFromDB(db) 37 | gotOut := testKVStruct{} 38 | _, gotErr := reader.GetRowByKey( 39 | ctx, 40 | &gotOut, 41 | "family1", 42 | "table1", 43 | "foo", 44 | ) 45 | 46 | if gotErr != nil { 47 | t.Errorf("Unexpected error %+v", gotErr) 48 | } 49 | 50 | if diff := cmp.Diff(gotOut, testKVStruct{"foo", "bar"}); diff != "" { 51 | t.Errorf("GetRowByKey out param mismatch\n%s", diff) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/segmentio/ctlstore 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/AlekSi/pointer v1.0.0 7 | github.com/aws/aws-sdk-go v1.37.8 8 | github.com/aws/aws-sdk-go-v2/config v1.18.40 9 | github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.84 10 | github.com/aws/aws-sdk-go-v2/service/s3 v1.38.5 11 | github.com/fsnotify/fsnotify v1.5.1 12 | github.com/go-sql-driver/mysql v1.4.1 13 | github.com/google/go-cmp v0.5.8 14 | github.com/google/uuid v1.1.2 15 | github.com/gorilla/mux v1.7.3 16 | github.com/julienschmidt/httprouter v1.2.0 17 | github.com/maxbrunsfeld/counterfeiter/v6 v6.4.1 18 | github.com/pkg/errors v0.9.1 19 | github.com/segmentio/cli v0.5.1 20 | github.com/segmentio/conf v1.1.0 21 | github.com/segmentio/errors-go v1.0.0 22 | github.com/segmentio/events/v2 v2.3.2 23 | github.com/segmentio/go-sqlite3 v1.14.22-segment 24 | github.com/segmentio/stats/v4 v4.6.2 25 | github.com/stretchr/testify v1.8.1 26 | golang.org/x/sync v0.6.0 27 | ) 28 | 29 | require ( 30 | github.com/aws/aws-sdk-go-v2 v1.21.0 // indirect 31 | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.13 // indirect 32 | github.com/aws/aws-sdk-go-v2/credentials v1.13.38 // indirect 33 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11 // indirect 34 | 
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 // indirect 35 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 // indirect 36 | github.com/aws/aws-sdk-go-v2/internal/ini v1.3.42 // indirect 37 | github.com/aws/aws-sdk-go-v2/internal/v4a v1.1.4 // indirect 38 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.14 // indirect 39 | github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.36 // indirect 40 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35 // indirect 41 | github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.15.4 // indirect 42 | github.com/aws/aws-sdk-go-v2/service/sso v1.14.0 // indirect 43 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.16.0 // indirect 44 | github.com/aws/aws-sdk-go-v2/service/sts v1.22.0 // indirect 45 | github.com/aws/smithy-go v1.14.2 // indirect 46 | github.com/davecgh/go-spew v1.1.1 // indirect 47 | github.com/jmespath/go-jmespath v0.4.0 // indirect 48 | github.com/mdlayher/genetlink v0.0.0-20190313224034-60417448a851 // indirect 49 | github.com/mdlayher/netlink v0.0.0-20190313131330-258ea9dff42c // indirect 50 | github.com/mdlayher/taskstats v0.0.0-20190313225729-7cbba52ee072 // indirect 51 | github.com/pmezard/go-difflib v1.0.0 // indirect 52 | github.com/segmentio/fasthash v0.0.0-20180216231524-a72b379d632e // indirect 53 | github.com/segmentio/go-snakecase v1.1.0 // indirect 54 | github.com/segmentio/objconv v1.0.1 // indirect 55 | golang.org/x/mod v0.15.0 // indirect 56 | golang.org/x/net v0.20.0 // indirect 57 | golang.org/x/sys v0.17.0 // indirect 58 | golang.org/x/tools v0.17.0 // indirect 59 | google.golang.org/appengine v1.6.7 // indirect 60 | gopkg.in/go-playground/assert.v1 v1.2.1 // indirect 61 | gopkg.in/go-playground/mold.v2 v2.2.0 // indirect 62 | gopkg.in/validator.v2 v2.0.0-20180514200540-135c24b11c19 // indirect 63 | gopkg.in/yaml.v2 v2.4.0 // indirect 64 | gopkg.in/yaml.v3 v3.0.1 // indirect 65 | ) 66 | 
-------------------------------------------------------------------------------- /initialize.go: -------------------------------------------------------------------------------- 1 | package ctlstore 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/segmentio/ctlstore/pkg/globalstats" 7 | "github.com/segmentio/stats/v4" 8 | ) 9 | 10 | type Config struct { 11 | // Stats specifies the config for reporting stats to the global 12 | // ctlstore stats namespace. 13 | // 14 | // By default, global stats are enabled with a set of sane defaults. 15 | Stats *globalstats.Config 16 | 17 | // LDBVersioning, if enabled, will instruct ctlstore to look for 18 | // LDBs inside of timestamp-delimited folders, and ctlstore will 19 | // hot-reload new LDBs as they appear. 20 | // 21 | // By default, this is disabled. 22 | LDBVersioning bool 23 | } 24 | 25 | var ldbVersioning bool 26 | 27 | func init() { 28 | // Enable globalstats by default. 29 | globalstats.Initialize(context.Background(), globalstats.Config{}) 30 | } 31 | 32 | // InitializeWithConfig sets up global state for thing including global 33 | // metrics globalstats data and possibly more as time goes on. 34 | func InitializeWithConfig(ctx context.Context, cfg Config) { 35 | if cfg.Stats != nil { 36 | // Initialize globalstats with the provided configuration: 37 | globalstats.Initialize(ctx, *cfg.Stats) 38 | } 39 | ldbVersioning = cfg.LDBVersioning 40 | } 41 | 42 | // Initialize sets up global state for thing including global 43 | // metrics globalstats data and possibly more as time goes on. 
44 | // 45 | // Deprecated: see InitializeWithConfig 46 | func Initialize(ctx context.Context, appName string, statsHandler stats.Handler) { 47 | globalstats.Initialize(ctx, globalstats.Config{ 48 | AppName: appName, 49 | StatsHandler: statsHandler, 50 | }) 51 | } 52 | -------------------------------------------------------------------------------- /ldb.go: -------------------------------------------------------------------------------- 1 | package ctlstore 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "sync" 7 | 8 | "github.com/segmentio/ctlstore/pkg/ldb" 9 | "github.com/segmentio/ctlstore/pkg/sqlite" 10 | ) 11 | 12 | const ( 13 | DefaultCtlstorePath = "/var/spool/ctlstore/" 14 | DefaultChangelogFilename = "change.log" 15 | defaultLDBVersioningSubdir = "versioned" 16 | ) 17 | 18 | var ( 19 | globalLDBDirPath = DefaultCtlstorePath 20 | globalLDBVersioningDirPath = filepath.Join(DefaultCtlstorePath, defaultLDBVersioningSubdir) 21 | globalCLPath = filepath.Join(DefaultCtlstorePath, DefaultChangelogFilename) 22 | globalLDBReadOnly = true 23 | globalReader *LDBReader 24 | globalReaderMu sync.RWMutex 25 | ) 26 | 27 | func init() { 28 | envPath := os.Getenv("CTLSTORE_PATH") 29 | if envPath != "" { 30 | globalLDBDirPath = envPath 31 | globalLDBVersioningDirPath = filepath.Join(envPath, defaultLDBVersioningSubdir) 32 | globalCLPath = filepath.Join(envPath, DefaultChangelogFilename) 33 | } 34 | sqlite.InitDriver() 35 | } 36 | 37 | // ReaderForPath opens an LDB at the provided path and returns an LDBReader 38 | // instance pointed at that LDB. 39 | func ReaderForPath(path string) (*LDBReader, error) { 40 | return newLDBReader(path) 41 | } 42 | 43 | // Reader returns an LDBReader that can be used globally. 
44 | func Reader() (*LDBReader, error) { 45 | globalReaderMu.RLock() 46 | defer globalReaderMu.RUnlock() 47 | 48 | if globalReader == nil { 49 | globalReaderMu.RUnlock() 50 | defer globalReaderMu.RLock() 51 | globalReaderMu.Lock() 52 | defer globalReaderMu.Unlock() 53 | 54 | if globalReader == nil { 55 | var reader *LDBReader 56 | var err error 57 | if ldbVersioning { 58 | reader, err = newVersionedLDBReader(globalLDBVersioningDirPath) 59 | } else { 60 | reader, err = newLDBReader(filepath.Join(globalLDBDirPath, ldb.DefaultLDBFilename)) 61 | } 62 | if err != nil { 63 | return nil, err 64 | } 65 | globalReader = reader 66 | } 67 | } 68 | 69 | return globalReader, nil 70 | } 71 | -------------------------------------------------------------------------------- /ldb_testing_test.go: -------------------------------------------------------------------------------- 1 | package ctlstore 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/segmentio/ctlstore/pkg/ldb" 11 | ) 12 | 13 | func TestLDBTestUtilCreateTableAndInsertRows(t *testing.T) { 14 | suite := []struct { 15 | desc string 16 | def LDBTestTableDef 17 | checkQuery string 18 | expectRows []map[string]interface{} 19 | }{ 20 | { 21 | desc: "Baseline No Rows", 22 | def: LDBTestTableDef{ 23 | Family: "family1", 24 | Name: "table1", 25 | Fields: [][]string{ 26 | {"string1", "string"}, 27 | {"integer1", "integer"}, 28 | }, 29 | KeyFields: []string{"string1"}, 30 | }, 31 | checkQuery: "SELECT string1, integer1 FROM family1___table1", 32 | }, 33 | { 34 | desc: "Insert Rows", 35 | def: LDBTestTableDef{ 36 | Family: "family1", 37 | Name: "table1", 38 | Fields: [][]string{ 39 | {"string1", "string"}, 40 | {"integer1", "integer"}, 41 | }, 42 | KeyFields: []string{"string1"}, 43 | Rows: [][]interface{}{ 44 | {"hello", 710}, 45 | }, 46 | }, 47 | expectRows: []map[string]interface{}{ 48 | { 49 | "string1": "hello", 50 | "integer1": 710, 51 | }, 52 | }, 53 | }, 54 | } 55 | 
56 | for i, testCase := range suite { 57 | t.Run(fmt.Sprintf("[%d]%s", i, testCase.desc), func(t *testing.T) { 58 | db, err := sql.Open("sqlite3", ":memory:") 59 | if err != nil { 60 | t.Fatalf("Unexpected error: %+v", err) 61 | } 62 | 63 | err = ldb.EnsureLdbInitialized(context.Background(), db) 64 | if err != nil { 65 | t.Fatalf("Couldn't initialize SQLite db, error %v", err) 66 | } 67 | 68 | tu := &LDBTestUtil{ 69 | DB: db, 70 | T: t, 71 | } 72 | tu.CreateTable(testCase.def) 73 | 74 | _, err = db.Exec(testCase.checkQuery) 75 | if err != nil { 76 | t.Errorf("Query Failed\nQuery: %s\nError: %+v", testCase.checkQuery, err) 77 | } 78 | 79 | if err == nil && testCase.expectRows != nil { 80 | actualTable := fmt.Sprintf("%s___%s", testCase.def.Family, testCase.def.Name) 81 | for _, row := range testCase.expectRows { 82 | hunks := []string{ 83 | "SELECT COUNT(*) FROM", 84 | actualTable, 85 | "WHERE", 86 | } 87 | params := []interface{}{} 88 | 89 | clock := 0 90 | for name, val := range row { 91 | if clock != 0 { 92 | hunks = append(hunks, "AND") 93 | } 94 | hunks = append(hunks, name, "= ?") 95 | params = append(params, val) 96 | clock++ 97 | } 98 | 99 | qs := strings.Join(hunks, " ") 100 | qrow := db.QueryRow(qs, params...) 
101 | cnt := 0 102 | err := qrow.Scan(&cnt) 103 | 104 | if err != nil { 105 | t.Errorf("Table query failed: %+v", err) 106 | } 107 | 108 | if cnt != 1 { 109 | t.Errorf("Didn't find row: %+v", row) 110 | } 111 | } 112 | } 113 | }) 114 | } 115 | } 116 | 117 | func TestLDBTestUtilReset(t *testing.T) { 118 | db, err := sql.Open("sqlite3", ":memory:") 119 | if err != nil { 120 | t.Fatalf("Unexpected error: %+v", err) 121 | } 122 | _, err = db.Exec(strings.Join([]string{ 123 | "CREATE TABLE family1___table1 (field1 VARCHAR);", 124 | "CREATE TABLE family1___table2 (field1 VARCHAR);", 125 | }, " ")) 126 | if err != nil { 127 | t.Fatalf("Unexpected error: %+v", err) 128 | } 129 | tu := &LDBTestUtil{ 130 | DB: db, 131 | T: t, 132 | } 133 | 134 | tu.Reset() 135 | 136 | for _, table := range []string{"table1", "table2"} { 137 | _, err = db.Exec("SELECT * FROM family1___" + table) 138 | if err == nil { 139 | t.Errorf("Expected to get an error querying family1.%s", table) 140 | } 141 | } 142 | 143 | } 144 | -------------------------------------------------------------------------------- /pkg/changelog/changelog_writer.go: -------------------------------------------------------------------------------- 1 | package changelog 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/pkg/errors" 7 | "github.com/segmentio/events/v2" 8 | ) 9 | 10 | type ( 11 | // WriteLine writes a line to something 12 | WriteLine interface { 13 | WriteLine(string) error 14 | } 15 | ChangelogWriter struct { 16 | WriteLine WriteLine 17 | } 18 | ChangelogEntry struct { 19 | Seq int64 20 | Family string 21 | Table string 22 | Key []interface{} 23 | } 24 | ) 25 | 26 | func NewChangelogEntry(seq int64, family string, table string, key []interface{}) *ChangelogEntry { 27 | return &ChangelogEntry{Seq: seq, Family: family, Table: table, Key: key} 28 | } 29 | 30 | func (w *ChangelogWriter) WriteChange(e ChangelogEntry) error { 31 | structure := struct { 32 | Seq int64 `json:"seq"` 33 | Family string 
`json:"family"` 34 | Table string `json:"table"` 35 | Key []interface{} `json:"key"` 36 | }{ 37 | e.Seq, 38 | e.Family, 39 | e.Table, 40 | e.Key, 41 | } 42 | 43 | bytes, err := json.Marshal(structure) 44 | if err != nil { 45 | return errors.Wrap(err, "error marshalling json") 46 | } 47 | 48 | events.Debug("changelogWriter.WriteChange: %{family}s.%{table}s => %{key}v", 49 | e.Family, e.Table, e.Key) 50 | 51 | return w.WriteLine.WriteLine(string(bytes)) 52 | } 53 | -------------------------------------------------------------------------------- /pkg/changelog/changelog_writer_test.go: -------------------------------------------------------------------------------- 1 | package changelog 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | type clwWriteLineMock struct { 10 | Lines []string 11 | } 12 | 13 | func (w *clwWriteLineMock) WriteLine(s string) error { 14 | if w.Lines == nil { 15 | w.Lines = []string{} 16 | } 17 | w.Lines = append(w.Lines, s) 18 | return nil 19 | } 20 | 21 | func TestWriteChange(t *testing.T) { 22 | mock := &clwWriteLineMock{} 23 | clw := ChangelogWriter{WriteLine: mock} 24 | 25 | // Chose this number just to see if it serializes 54-bit integers 26 | // properly, because JavaScript is *INSANE* 27 | err := clw.WriteChange(ChangelogEntry{ 28 | Seq: 42, 29 | Family: "family1", 30 | Table: "table1", 31 | Key: []interface{}{18014398509481984, "foo"}, 32 | }) 33 | require.NoError(t, err) 34 | require.EqualValues(t, 1, len(mock.Lines)) 35 | require.Equal(t, `{"seq":42,"family":"family1","table":"table1","key":[18014398509481984,"foo"]}`, mock.Lines[0]) 36 | } 37 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/add_fields.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "net/http" 8 | 9 | "github.com/segmentio/cli" 10 | ) 11 | 12 | var 
cliAddFields = &cli.CommandFunc{ 13 | Help: "Add fields to an existing table", 14 | Desc: unindent(` 15 | This command makes an HTTP request to the executive service 16 | to add fields to an existing table 17 | 18 | Example: 19 | 20 | add-fields --family foo --table bar --field name:string 21 | `), 22 | Func: func(ctx context.Context, config struct { 23 | flagBase 24 | flagExecutive 25 | flagFamily 26 | flagTable 27 | flagFields 28 | }) (err error) { 29 | executive := config.MustExecutive() 30 | familyName := config.MustFamily() 31 | tableName := config.MustTable() 32 | fields := config.MustFields() 33 | 34 | var payload struct { 35 | Fields [][]string `json:"fields"` 36 | } 37 | for _, field := range fields { 38 | payload.Fields = append(payload.Fields, []string{field.name, field.typ}) 39 | } 40 | payloadBytes, err := json.Marshal(payload) 41 | if err != nil { 42 | bail("could not marshal payload: %s", err) 43 | } 44 | url := executive + "/families/" + familyName + "/tables/" + tableName 45 | req, err := http.NewRequest("PUT", url, bytes.NewReader(payloadBytes)) 46 | if err != nil { 47 | bail("could not create request: %s", err) 48 | } 49 | resp, err := httpClient.Do(req) 50 | if err != nil { 51 | bail("could not make request: %s", err) 52 | } 53 | defer resp.Body.Close() 54 | switch resp.StatusCode { 55 | case http.StatusOK: 56 | case http.StatusConflict: 57 | bail("One or more columns already exist") 58 | default: 59 | bailResponse(resp, "could not add fields") 60 | } 61 | return nil 62 | }, 63 | } 64 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/create_family.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/segmentio/cli" 11 | ) 12 | 13 | var cliCreateFamily = &cli.CommandFunc{ 14 | Help: "Create a new table family", 15 | Desc: 
unindent(fmt.Sprintf(` 16 | Create a new table family 17 | 18 | This command makes an HTTP request to the executive service 19 | to create a new table family. 20 | 21 | Example: 22 | 23 | %s create-family foo 24 | `, filepath.Base(os.Args[0]))), 25 | Func: func(ctx context.Context, config struct { 26 | flagBase 27 | flagExecutive 28 | }, args []string) (err error) { 29 | if len(args) != 1 { 30 | bail("Family required") 31 | } 32 | executive := config.MustExecutive() 33 | familyName := args[0] 34 | url := executive + "/families/" + familyName 35 | req, err := http.NewRequest("POST", url, nil) 36 | if err != nil { 37 | bail("could not create request: %s", err) 38 | } 39 | resp, err := httpClient.Do(req) 40 | if err != nil { 41 | bail("could not make request: %s", err) 42 | } 43 | defer resp.Body.Close() 44 | switch resp.StatusCode { 45 | case http.StatusOK: 46 | case http.StatusConflict: 47 | fmt.Println("Family already exists") 48 | default: 49 | bailResponse(resp, "could not create family '%s'", familyName) 50 | } 51 | return nil 52 | }, 53 | } 54 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/create_table.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "net/http" 9 | 10 | "github.com/segmentio/cli" 11 | ) 12 | 13 | var cliCreateTable = &cli.CommandFunc{ 14 | Help: "Create a new table", 15 | Desc: unindent(` 16 | Create a new table 17 | 18 | This command makes an HTTP request to the executive service 19 | to create a new table. 
20 | 21 | Example: 22 | 23 | create-table --family foo --field name:string --field foo:integer --key-field name testtable 24 | 25 | Resulting schema: 26 | 27 | CREATE TABLE foo___testtable (name VARCHAR(191), foo INTEGER, PRIMARY KEY(name)); 28 | `), 29 | Func: func(ctx context.Context, config struct { 30 | flagBase 31 | flagExecutive 32 | flagFamily 33 | flagFields 34 | flagKeyFields 35 | }, args []string) error { 36 | executive := config.MustExecutive() 37 | familyName := config.MustFamily() 38 | fields := config.MustFields() 39 | keyFields := config.MustKeyFields() 40 | 41 | tableName := args[0] 42 | 43 | // todo: dedupe this declaration 44 | var payload struct { 45 | Fields [][]string `json:"fields"` 46 | KeyFields []string `json:"keyFields"` 47 | } 48 | for _, field := range fields { 49 | payload.Fields = append(payload.Fields, []string{field.name, field.typ}) 50 | } 51 | for _, keyField := range keyFields { 52 | payload.KeyFields = append(payload.KeyFields, keyField) 53 | } 54 | payloadBytes, err := json.Marshal(payload) 55 | if err != nil { 56 | bail("could not marshal payload: %s", err) 57 | } 58 | url := executive + "/families/" + familyName + "/tables/" + tableName 59 | req, err := http.NewRequest("POST", url, bytes.NewReader(payloadBytes)) 60 | if err != nil { 61 | bail("could not create request: %s", err) 62 | } 63 | resp, err := httpClient.Do(req) 64 | if err != nil { 65 | bail("could not make request: %s", err) 66 | } 67 | defer resp.Body.Close() 68 | switch resp.StatusCode { 69 | case http.StatusOK: 70 | case http.StatusConflict: 71 | fmt.Println("Table already exists") 72 | default: 73 | bailResponse(resp, "could not create table '%s'", tableName) 74 | } 75 | return nil 76 | }, 77 | } 78 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/flags.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "path" 5 | "strings" 6 | 7 | 
"github.com/segmentio/ctlstore" 8 | "github.com/segmentio/ctlstore/pkg/ldb" 9 | ) 10 | 11 | type flagBase struct { 12 | } 13 | 14 | type flagRowsPerMinute struct { 15 | RowsPerMinute int64 `flag:"rows-per-minute"` 16 | } 17 | 18 | type flagWriter struct { 19 | Writer string `flag:"writer"` 20 | } 21 | 22 | func (f flagWriter) MustWriter() string { 23 | if f.Writer == "" { 24 | bail("Writer required") 25 | } 26 | return f.Writer 27 | } 28 | 29 | type flagQuiet struct { 30 | Quiet bool `flag:"-q,--quiet"` 31 | } 32 | 33 | type flagExecutive struct { 34 | Executive string `flag:"-e,--executive" default:"ctlstore-executive.segment.local"` 35 | } 36 | 37 | func (f flagExecutive) MustExecutive() string { 38 | return normalizeURL(f.Executive) 39 | } 40 | 41 | type flagFamily struct { 42 | Family string `flag:"-f,--family"` 43 | } 44 | 45 | func (f flagFamily) MustFamily() string { 46 | if f.Family == "" { 47 | bail("Family required") 48 | } 49 | return f.Family 50 | } 51 | 52 | type flagTable struct { 53 | Table string `flag:"-t,--table"` 54 | } 55 | 56 | func (f flagTable) MustTable() string { 57 | if f.Table == "" { 58 | bail("Table required") 59 | } 60 | return f.Table 61 | } 62 | 63 | type flagSizeLimits struct { 64 | MaxSize int64 `flag:"--max-size" default:"104857600"` 65 | WarnSize int64 `flag:"--warn-size" default:"52428800"` 66 | } 67 | 68 | func (f flagSizeLimits) MustMaxSize() int64 { 69 | switch { 70 | case f.MaxSize < 0: 71 | bail("Max size cannot be negative") 72 | case f.MaxSize == 0: 73 | bail("Max size required") 74 | } 75 | return f.MaxSize 76 | } 77 | 78 | func (f flagSizeLimits) MustWarnSize() int64 { 79 | switch { 80 | case f.WarnSize < 0: 81 | bail("Warn size cannot be negative") 82 | case f.WarnSize == 0: 83 | bail("Warn size required") 84 | } 85 | return f.WarnSize 86 | } 87 | 88 | type flagFields struct { 89 | Fields []string `flag:"--field"` 90 | } 91 | 92 | func (f flagFields) MustFields() (res []field) { 93 | for _, val := range f.Fields { 94 | 
parts := strings.Split(val, ":") 95 | if len(parts) != 2 { 96 | bail("invalid field: %s", val) 97 | } 98 | res = append(res, field{ 99 | name: parts[0], 100 | typ: parts[1], 101 | }) 102 | } 103 | return 104 | } 105 | 106 | type flagKeyFields struct { 107 | KeyFields []string `flag:"--key-field"` 108 | } 109 | 110 | func (f flagKeyFields) MustKeyFields() []string { 111 | return f.KeyFields 112 | } 113 | 114 | type flagLDBPath struct { 115 | LDBPath string `flag:"-l,--ldb" default:"/var/spool/ctlstore/ldb.db"` 116 | } 117 | 118 | var ( 119 | defaultLDBPath = path.Join(ctlstore.DefaultCtlstorePath, ldb.DefaultLDBFilename) 120 | ) 121 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/read_keys.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "encoding/hex" 6 | "fmt" 7 | "os" 8 | "regexp" 9 | "sort" 10 | "text/tabwriter" 11 | "time" 12 | 13 | "github.com/pkg/errors" 14 | "github.com/segmentio/cli" 15 | "github.com/segmentio/ctlstore" 16 | ) 17 | 18 | var cliReadKeys = &cli.CommandFunc{ 19 | Help: "read-keys [key1], [key2], ... [keyN]", 20 | Desc: unindent(` 21 | Reads a row from a local LDB 22 | 23 | This command reads one row given the specified family, table, and keys. 24 | The order of the keys must be specified in the order they appear in the 25 | schema. 26 | 27 | The output of this command will be a table of columnName -> columnValue, 28 | sorted by the column names. 
29 | `), 30 | Func: func(ctx context.Context, config struct { 31 | flagBase 32 | flagQuiet 33 | flagLDBPath 34 | flagFamily 35 | flagTable 36 | }, args []string) error { 37 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 38 | defer cancel() 39 | 40 | ldbPath := config.LDBPath 41 | familyName := config.MustFamily() 42 | tableName := config.MustTable() 43 | quiet := config.Quiet 44 | 45 | keys, err := getKeys(args) 46 | if err != nil { 47 | return err 48 | } 49 | reader, err := ctlstore.ReaderForPath(ldbPath) 50 | if err != nil { 51 | return errors.Wrap(err, "ldb reader for path") 52 | } 53 | defer reader.Close() 54 | resMap := make(map[string]interface{}) 55 | found, err := reader.GetRowByKey(ctx, resMap, familyName, tableName, keys...) 56 | if err != nil { 57 | bail("Could not read row: %s", err) 58 | } 59 | if !found { 60 | bail("Not found") 61 | } 62 | w := tabwriter.NewWriter(os.Stdout, 0, 0, 1, ' ', tabwriter.TabIndent) 63 | if !quiet { 64 | fmt.Fprintln(w, "COLUMN\tVALUE") 65 | fmt.Fprintln(w, "------\t-----") 66 | } 67 | var sortedResultKeys []string 68 | for key := range resMap { 69 | sortedResultKeys = append(sortedResultKeys, key) 70 | } 71 | sort.Strings(sortedResultKeys) 72 | for _, key := range sortedResultKeys { 73 | val := resMap[key] 74 | valFormat := "%v" 75 | switch val.(type) { 76 | case []byte: 77 | valFormat = "0x%x" 78 | } 79 | fmt.Fprintln(w, fmt.Sprintf("%s\t"+valFormat, key, val)) 80 | } 81 | w.Flush() 82 | return nil 83 | }, 84 | } 85 | 86 | var hexKeyRE = regexp.MustCompile(`^0x([0-9a-fA-F]*)$`) 87 | 88 | // getKeys converts each key input into a type that can be passed 89 | // into the reader.GetRowByKey method. Specifically, it checks to 90 | // see if it's a binary literal, and handles that case explicitly. 91 | // The other types just get passed through as regular strings. 
92 | func getKeys(args []string) (res []interface{}, err error) { 93 | for _, arg := range args { 94 | key, err := parseKey(arg) 95 | if err != nil { 96 | return res, err 97 | } 98 | res = append(res, key) 99 | } 100 | return res, nil 101 | } 102 | 103 | func parseKey(key string) (interface{}, error) { 104 | parts := hexKeyRE.FindStringSubmatch(key) 105 | if len(parts) != 2 { 106 | // it's not a hex literal, just return the key itself 107 | return key, nil 108 | } 109 | hex, err := hex.DecodeString(parts[1]) 110 | if err != nil { 111 | return nil, errors.Errorf("could not parse '%s' as hex", parts[1]) 112 | } 113 | return hex, nil 114 | } 115 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/read_keys_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "encoding/hex" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestGetKeys(t *testing.T) { 11 | decodeHex := func(h string) []byte { 12 | res, err := hex.DecodeString(h) 13 | require.NoError(t, err) 14 | return res 15 | } 16 | for _, test := range []struct { 17 | name string 18 | input []string 19 | output []interface{} 20 | err error 21 | }{ 22 | { 23 | name: "noargs", 24 | input: []string{}, 25 | output: nil, 26 | err: nil, 27 | }, 28 | { 29 | name: "string arg", 30 | input: []string{`foo`}, 31 | output: []interface{}{"foo"}, 32 | err: nil, 33 | }, 34 | { 35 | name: "hex arg", 36 | input: []string{"0xabcd"}, 37 | output: []interface{}{decodeHex("abcd")}, 38 | err: nil, 39 | }, 40 | { 41 | name: "hex arg mixed case", 42 | input: []string{"0xABcD"}, 43 | output: []interface{}{decodeHex("abcd")}, 44 | err: nil, 45 | }, 46 | { 47 | name: "pass through of non-hex keys", 48 | input: []string{"0xzz", "0xaz"}, 49 | output: []interface{}{"0xzz", "0xaz"}, 50 | err: nil, 51 | }, 52 | } { 53 | t.Run(test.name, func(t *testing.T) { 54 | out, err := 
getKeys(test.input) 55 | if test.err != nil { 56 | require.EqualError(t, err, test.err.Error()) 57 | } else { 58 | require.NoError(t, err) 59 | } 60 | require.EqualValues(t, test.output, out) 61 | }) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/read_seq.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/segmentio/cli" 9 | "github.com/segmentio/ctlstore" 10 | ) 11 | 12 | var cliReadSeq = &cli.CommandFunc{ 13 | Help: "Read last sequenece from the LDB", 14 | Func: func(ctx context.Context, config struct { 15 | flagLDBPath 16 | }) error { 17 | ldbPath := config.LDBPath 18 | reader, err := ctlstore.ReaderForPath(ldbPath) 19 | if err != nil { 20 | bail("Could not get reader: %s", err) 21 | } 22 | defer reader.Close() 23 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 24 | defer cancel() 25 | 26 | seq, err := reader.GetLastSequence(ctx) 27 | if err != nil { 28 | bail("Could not get sequence: %s", err) 29 | } 30 | fmt.Println(seq) 31 | return nil 32 | }, 33 | } 34 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/root.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/segmentio/cli" 8 | ) 9 | 10 | func Execute() { 11 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 12 | defer cancel() 13 | cli.ExecContext(ctx, cli.CommandSet{ 14 | "table-limits": cliTableLimits, 15 | "create-table": cliCreateTable, 16 | "create-family": cliCreateFamily, 17 | "add-fields": cliAddFields, 18 | "read-keys": cliReadKeys, 19 | "read-seq": cliReadSeq, 20 | "writer-limits": cliWriterLimits, 21 | }) 22 | } 23 | 
-------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/table_limits.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "os" 9 | "text/tabwriter" 10 | 11 | "github.com/segmentio/cli" 12 | "github.com/segmentio/ctlstore/pkg/limits" 13 | "github.com/segmentio/ctlstore/pkg/utils" 14 | ) 15 | 16 | var cliTableLimits = cli.CommandSet{ 17 | "read": &cli.CommandFunc{ 18 | Help: "Read all table limits", 19 | Func: func(ctx context.Context, config struct { 20 | flagBase 21 | flagExecutive 22 | }) error { 23 | url := config.MustExecutive() + "/limits/tables" 24 | req, err := http.NewRequest(http.MethodGet, url, nil) 25 | if err != nil { 26 | bail("could not create request: %s", err) 27 | } 28 | resp, err := httpClient.Do(req) 29 | if err != nil { 30 | bail("could not make request: %s", err) 31 | } 32 | defer resp.Body.Close() 33 | if resp.StatusCode != http.StatusOK { 34 | bailResponse(resp, "could not read limits") 35 | } 36 | var tsl limits.TableSizeLimits 37 | if err := json.NewDecoder(resp.Body).Decode(&tsl); err != nil { 38 | bail("could not decode response: %s", err) 39 | } 40 | fmt.Printf("warn: %d bytes\n", tsl.Global.WarnSize) 41 | fmt.Printf("max : %d bytes\n", tsl.Global.MaxSize) 42 | if len(tsl.Tables) == 0 { 43 | return nil 44 | } 45 | fmt.Println() 46 | w := tabwriter.NewWriter(os.Stdout, 0, 0, 1, ' ', tabwriter.TabIndent) 47 | fmt.Fprintln(w, "FAMILY\tTABLE\tWARN\tMAX") 48 | fmt.Fprintln(w, "------\t-----\t----\t---") 49 | for _, t := range tsl.Tables { 50 | fmt.Fprintf(w, "%s\t%s\t%d\t%d\n", t.Family, t.Table, t.WarnSize, t.MaxSize) 51 | } 52 | return w.Flush() 53 | }, 54 | }, 55 | "update": &cli.CommandFunc{ 56 | Help: "Update a table limit", 57 | Func: func(ctx context.Context, config struct { 58 | flagBase 59 | flagExecutive 60 | flagFamily 61 | flagTable 62 | flagSizeLimits 63 
| }) error { 64 | executive := config.MustExecutive() 65 | familyName := config.MustFamily() 66 | tableName := config.MustTable() 67 | maxSize := config.MustMaxSize() 68 | warnSize := config.MustWarnSize() 69 | if warnSize > maxSize { 70 | bail("warnSize must be <= maxSize") 71 | } 72 | url := executive + "/limits/tables/" + familyName + "/" + tableName 73 | payload := limits.SizeLimits{ 74 | WarnSize: warnSize, 75 | MaxSize: maxSize, 76 | } 77 | req, err := http.NewRequest(http.MethodPost, url, utils.NewJsonReader(payload)) 78 | if err != nil { 79 | bail("could not build request: %s", err) 80 | } 81 | resp, err := httpClient.Do(req) 82 | if err != nil { 83 | bail("could not make request: %s", err) 84 | } 85 | if resp.StatusCode != http.StatusOK { 86 | bailResponse(resp, "could not update table limit") 87 | } 88 | return nil 89 | }, 90 | }, 91 | "delete": &cli.CommandFunc{ 92 | Help: "Delete a table limit", 93 | Func: func(ctx context.Context, config struct { 94 | flagBase 95 | flagExecutive 96 | flagFamily 97 | flagTable 98 | }) error { 99 | executive := config.MustExecutive() 100 | familyName := config.MustFamily() 101 | tableName := config.MustTable() 102 | url := executive + "/limits/tables/" + familyName + "/" + tableName 103 | req, err := http.NewRequest(http.MethodDelete, url, nil) 104 | if err != nil { 105 | bail("could not build request: %s", err) 106 | } 107 | resp, err := httpClient.Do(req) 108 | if err != nil { 109 | bail("could not make request: %s", err) 110 | } 111 | if resp.StatusCode != http.StatusOK { 112 | bailResponse(resp, "could not delete table limit") 113 | } 114 | return nil 115 | }, 116 | }, 117 | } 118 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/cmd/table_limits_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/segmentio/ctlstore/pkg/limits" 8 | 
// field represents a table field: a column name plus its type.
type field struct {
	name string
	typ  string
}

var (
	httpClient = &http.Client{}
)

// bailResponse behaves like bail, but appends the status code and body of
// a failed http.Response to the message before exiting.
func bailResponse(response *http.Response, msg string, args ...interface{}) {
	// best-effort read: a failure here just produces an empty body below
	body, _ := ioutil.ReadAll(response.Body)
	detail := fmt.Sprintf("server returned [%d]: %s", response.StatusCode, body)
	fmt.Fprintf(os.Stderr, "%s: %s\n", fmt.Sprintf(msg, args...), detail)
	os.Exit(1)
}

// bail prints a formatted message to stderr and exits with status=1.
func bail(msg string, args ...interface{}) {
	fmt.Fprintf(os.Stderr, msg+"\n", args...)
	os.Exit(1)
}

// normalizeURL prefixes val with "http://" when no scheme is present and
// bails out if the result is not a parseable URL.
func normalizeURL(val string) string {
	switch {
	case strings.HasPrefix(val, "http://"), strings.HasPrefix(val, "https://"):
		// scheme already present; leave as-is
	default:
		val = "http://" + val
	}
	if _, err := url.Parse(val); err != nil {
		bail("invalid url: %v", err)
	}
	return val
}

// unindent formats long help text before it's printed to the console:
// it strips the leading/trailing whitespace from every line so that
// indentation used for code readability never reaches the terminal.
func unindent(str string) string {
	out := new(bytes.Buffer)
	for _, line := range strings.Split(strings.TrimSpace(str), "\n") {
		fmt.Fprintln(out, strings.TrimSpace(line))
	}
	return out.String()
}
limits") 38 | } 39 | var wrl limits.WriterRateLimits 40 | b, err := ioutil.ReadAll(resp.Body) 41 | if err != nil { 42 | bail("could not read response: %s", err) 43 | } 44 | if err := json.Unmarshal(b, &wrl); err != nil { 45 | bail("could not decode response: %s", err) 46 | } 47 | fmt.Println("default:", wrl.Global) 48 | if len(wrl.Writers) == 0 { 49 | return nil 50 | } 51 | fmt.Println() 52 | w := tabwriter.NewWriter(os.Stdout, 0, 0, 1, ' ', tabwriter.TabIndent) 53 | fmt.Fprintln(w, "WRITER\tLIMIT") 54 | fmt.Fprintln(w, "------\t-----") 55 | for _, t := range wrl.Writers { 56 | fmt.Fprintf(w, "%s\t%s\n", t.Writer, t.RateLimit) 57 | } 58 | return w.Flush() 59 | }, 60 | }, 61 | "update": &cli.CommandFunc{ 62 | Help: "Add or update a writer limit", 63 | Func: func(ctx context.Context, config struct { 64 | flagBase 65 | flagExecutive 66 | flagRowsPerMinute 67 | flagWriter 68 | }) error { 69 | executive := config.MustExecutive() 70 | writer := config.MustWriter() 71 | rowsPerMinute := config.RowsPerMinute 72 | url := executive + "/limits/writers/" + writer 73 | payload := limits.RateLimit{ 74 | Amount: rowsPerMinute, 75 | Period: time.Minute, 76 | } 77 | req, err := http.NewRequest(http.MethodPost, url, utils.NewJsonReader(payload)) 78 | if err != nil { 79 | bail("could not build request: %s", err) 80 | } 81 | resp, err := httpClient.Do(req) 82 | if err != nil { 83 | bail("could not make request: %s", err) 84 | } 85 | if resp.StatusCode != http.StatusOK { 86 | bailResponse(resp, "could not update writer limit") 87 | } 88 | return nil 89 | }, 90 | }, 91 | "delete": &cli.CommandFunc{ 92 | Help: "Delete a writer limit", 93 | Func: func(ctx context.Context, config struct { 94 | flagBase 95 | flagExecutive 96 | flagWriter 97 | }) error { 98 | executive := config.MustExecutive() 99 | writer := config.MustWriter() 100 | url := executive + "/limits/writers/" + writer 101 | req, err := http.NewRequest(http.MethodDelete, url, nil) 102 | if err != nil { 103 | bail("could not build 
request: %s", err) 104 | } 105 | resp, err := httpClient.Do(req) 106 | if err != nil { 107 | bail("could not make request: %s", err) 108 | } 109 | if resp.StatusCode != http.StatusOK { 110 | bailResponse(resp, "could not delete writer limit") 111 | } 112 | return nil 113 | }, 114 | }, 115 | } 116 | -------------------------------------------------------------------------------- /pkg/cmd/ctlstore-cli/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "github.com/segmentio/ctlstore/pkg/cmd/ctlstore-cli/cmd" 4 | 5 | func main() { 6 | cmd.Execute() 7 | } 8 | -------------------------------------------------------------------------------- /pkg/ctldb/ctldb.go: -------------------------------------------------------------------------------- 1 | package ctldb 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "strings" 7 | ) 8 | 9 | const LimiterDBSchemaUp = ` 10 | CREATE TABLE max_table_sizes ( 11 | family_name VARCHAR(30) NOT NULL, /* limit pulled from validate.go */ 12 | table_name VARCHAR(50) NOT NULL, /* limit pulled from validate.go */ 13 | warn_size_bytes BIGINT NOT NULL DEFAULT 0, 14 | max_size_bytes BIGINT NOT NULL DEFAULT 0, 15 | PRIMARY KEY (family_name, table_name) 16 | ); 17 | 18 | CREATE TABLE max_writer_rates ( 19 | writer_name VARCHAR(50) NOT NULL, /* limit pulled from validate.go */ 20 | max_rows_per_minute BIGINT NOT NULL , 21 | PRIMARY KEY (writer_name) 22 | ); 23 | 24 | CREATE TABLE writer_usage ( 25 | writer_name VARCHAR(50) NOT NULL, /* limit pulled from validate.go */ 26 | bucket BIGINT NOT NULL, 27 | amount BIGINT NOT NULL , 28 | PRIMARY KEY (writer_name, bucket) 29 | ); ` 30 | 31 | var CtlDBSchemaByDriver = map[string]string{ 32 | "mysql": ` 33 | 34 | ALTER DATABASE CHARACTER SET = utf8mb4 COLLATE = utf8mb4_unicode_ci; 35 | 36 | CREATE TABLE families ( 37 | id INTEGER AUTO_INCREMENT PRIMARY KEY, 38 | name VARCHAR(191) NOT NULL, 39 | UNIQUE KEY name (name) 40 | ); 
41 | 42 | CREATE TABLE mutators ( 43 | writer VARCHAR(191) NOT NULL PRIMARY KEY, 44 | secret VARCHAR(255) NOT NULL, 45 | cookie BLOB(1024) NOT NULL, 46 | clock BIGINT NOT NULL DEFAULT 0 47 | ); 48 | 49 | CREATE TABLE ctlstore_dml_ledger ( 50 | seq INTEGER AUTO_INCREMENT PRIMARY KEY, 51 | leader_ts DATETIME DEFAULT CURRENT_TIMESTAMP, 52 | statement MEDIUMTEXT NOT NULL 53 | ); 54 | 55 | CREATE TABLE locks ( 56 | id VARCHAR(191) NOT NULL PRIMARY KEY, 57 | clock BIGINT NOT NULL DEFAULT 0 58 | ); 59 | 60 | INSERT INTO locks VALUES('ledger', 0); 61 | 62 | ` + LimiterDBSchemaUp, 63 | "sqlite3": ` 64 | 65 | CREATE TABLE families ( 66 | id INTEGER PRIMARY KEY AUTOINCREMENT, 67 | name VARCHAR(191) NOT NULL UNIQUE 68 | ); 69 | 70 | CREATE TABLE mutators ( 71 | writer VARCHAR(191) NOT NULL PRIMARY KEY, 72 | secret VARCHAR(255), 73 | cookie BLOB(1024) NOT NULL, 74 | clock INTEGER NOT NULL DEFAULT 0 75 | ); 76 | 77 | CREATE TABLE ctlstore_dml_ledger ( 78 | seq INTEGER PRIMARY KEY AUTOINCREMENT, 79 | leader_ts DATETIME DEFAULT CURRENT_TIMESTAMP, 80 | statement TEXT NOT NULL 81 | ); 82 | 83 | CREATE TABLE locks ( 84 | id VARCHAR(191) NOT NULL PRIMARY KEY, 85 | clock INTEGER NOT NULL DEFAULT 0 86 | ); 87 | 88 | INSERT INTO locks VALUES('ledger', 0); 89 | ` + LimiterDBSchemaUp, 90 | } 91 | 92 | func InitializeCtlDB(db *sql.DB, driverFunc func(driver driver.Driver) (name string)) error { 93 | driverName := driverFunc(db.Driver()) 94 | schema := CtlDBSchemaByDriver[driverName] 95 | statements := strings.Split(schema, ";") 96 | 97 | for _, statement := range statements { 98 | tsql := strings.TrimSpace(statement) 99 | if tsql == "" { 100 | continue 101 | } 102 | _, err := db.Exec(tsql) 103 | if err != nil { 104 | return err 105 | } 106 | } 107 | 108 | return nil 109 | } 110 | -------------------------------------------------------------------------------- /pkg/ctldb/ctldb_test_helpers.go: -------------------------------------------------------------------------------- 1 | package ctldb 
// SetCtldbDSNParameters appends the standard set of ctldb connection
// parameters (collation, timeout, sql_mode) to the given DSN. The query
// string is re-encoded, so parameters end up in sorted order.
func SetCtldbDSNParameters(dsn string) (string, error) {
	parameters := map[string]string{
		"collation": "utf8mb4_unicode_ci",
		"timeout":   "5s",
		"sql_mode":  "'NO_BACKSLASH_ESCAPES,ANSI_QUOTES'",
	}
	var err error
	for name, value := range parameters {
		if dsn, err = AddParameterToDSN(dsn, name, value); err != nil {
			return "", err
		}
	}
	return dsn, nil
}

// AddParameterToDSN parses dsn as a URL and adds key=value to its query
// string, returning the rewritten DSN.
func AddParameterToDSN(dsn string, key string, value string) (string, error) {
	parsed, err := url.Parse(dsn)
	if err != nil {
		return "", err
	}
	query := parsed.Query()
	query.Add(key, value)
	parsed.RawQuery = query.Encode()
	return parsed.String(), nil
}
24 | 25 | // IncrDefault increments the default error metric 26 | func IncrDefault(tags ...stats.Tag) { 27 | Incr(defaultErrName, tags...) 28 | } 29 | 30 | // Incr increments an error metric, along with the default error metric 31 | func Incr(name string, tags ...stats.Tag) { 32 | stats.Incr(name, tags...) 33 | if name == defaultErrName { 34 | // don't increment the default error twice 35 | return 36 | } 37 | // add a tag to indicate the name of the original error. We can then 38 | // view that tag in datadog to figure out what the error was. 39 | newTags := make([]stats.Tag, len(tags), len(tags)+1) 40 | copy(newTags, tags) 41 | newTags = append(newTags, stats.T("error", name)) 42 | stats.Incr(defaultErrName, newTags...) 43 | } 44 | 45 | // These are here because there's a need for a set of errors that have roughly 46 | // REST/HTTP compatibility, but aren't directly coupled to that interface. Lower 47 | // layers of the system can generate these errors while still making sense in 48 | // any context. 
// baseError is the shared representation of the REST-flavored error
// types below: a bare message string.
type baseError struct {
	Err string
}

// ConflictError corresponds roughly to HTTP 409.
type ConflictError baseError

func (e ConflictError) Error() string { return e.Err }

// BadRequestError corresponds roughly to HTTP 400.
type BadRequestError baseError

func (e BadRequestError) Error() string { return e.Err }

// BadRequest builds a *BadRequestError from a format string.
func BadRequest(format string, args ...interface{}) error {
	e := &BadRequestError{Err: fmt.Sprintf(format, args...)}
	return e
}

// NotFoundError corresponds roughly to HTTP 404.
type NotFoundError baseError

func (e NotFoundError) Error() string { return e.Err }

// NotFound builds a *NotFoundError from a format string.
func NotFound(format string, args ...interface{}) error {
	e := &NotFoundError{Err: fmt.Sprintf(format, args...)}
	return e
}

// PayloadTooLargeError corresponds roughly to HTTP 413.
type PayloadTooLargeError baseError

func (e PayloadTooLargeError) Error() string { return e.Err }

// RateLimitExceededErr corresponds roughly to HTTP 429.
type RateLimitExceededErr baseError

func (e RateLimitExceededErr) Error() string { return e.Err }

// InsufficientStorageErr corresponds roughly to HTTP 507.
type InsufficientStorageErr baseError

func (e InsufficientStorageErr) Error() string { return e.Err }
// entry mirrors one JSON row of the changelog, e.g.
// {"seq":1,"family":"fam","table":"foo","key":[{"name":"id","type":"int","value":1}]}
type entry struct {
	Seq    int64  `json:"seq"`
	Family string `json:"family"`
	Table  string `json:"table"`
	Key    []Key  `json:"key"`
}

// event translates the changelog entry into the Event that the Iterator
// hands back to callers.
func (e entry) event() Event {
	update := RowUpdate{
		FamilyName: e.Family,
		TableName:  e.Table,
		Keys:       e.Key,
	}
	return Event{Sequence: e.Seq, RowUpdate: update}
}

// Event is the type that the Iterator produces.
type Event struct {
	Sequence  int64
	RowUpdate RowUpdate
}

// RowUpdate represents a single row update.
type RowUpdate struct {
	FamilyName string `json:"family"`
	TableName  string `json:"table"`
	Keys       []Key  `json:"keys"`
}

// Key represents a single primary key column value and metadata.
type Key struct {
	Name  string      `json:"name"`
	Type  string      `json:"type"`
	Value interface{} `json:"value"`
}
-------------------------------------------------------------------------------- /pkg/event/fake_log_writer.go: -------------------------------------------------------------------------------- 1 | package event 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "os" 8 | "sync/atomic" 9 | "time" 10 | 11 | "github.com/segmentio/errors-go" 12 | "github.com/segmentio/events/v2" 13 | ) 14 | 15 | type fakeLogWriter struct { 16 | path string 17 | family string 18 | table string 19 | delay time.Duration 20 | rotateAfter int 21 | rotateAfterBytes int 22 | seq int64 23 | rotations int64 24 | } 25 | 26 | // writes N events to the log 27 | func (w *fakeLogWriter) writeN(ctx context.Context, n int) error { 28 | f, err := os.Create(w.path) 29 | if err != nil { 30 | return errors.Wrap(err, "create file") 31 | } 32 | defer func() { 33 | events.Debug("Done writing %{num}d events", n) 34 | if err == nil { 35 | err = f.Close() 36 | } 37 | }() 38 | bw := bufio.NewWriter(f) 39 | total := 0 40 | these := 0 41 | for total < n && ctx.Err() == nil { 42 | total++ 43 | these++ 44 | entry := entry{ 45 | Seq: atomic.LoadInt64(&w.seq), 46 | Family: w.family, 47 | Table: w.table, 48 | Key: []Key{ 49 | { 50 | Name: "id-column", 51 | Type: "int", 52 | Value: 42, 53 | }, 54 | }, 55 | } 56 | atomic.AddInt64(&w.seq, 1) 57 | err := json.NewEncoder(bw).Encode(entry) 58 | if err != nil { 59 | return errors.Wrap(err, "write event") 60 | } 61 | if err := bw.Flush(); err != nil { 62 | return errors.Wrap(err, "flush") 63 | } 64 | time.Sleep(w.delay) 65 | 66 | doRotate := false 67 | if w.rotateAfterBytes > 0 { 68 | info, err := os.Stat(w.path) 69 | if err != nil { 70 | return errors.Wrap(err, "stat path") 71 | } 72 | // fmt.Println(info.Size(), w.rotateAfterBytes) 73 | if info.Size() > int64(w.rotateAfterBytes) { 74 | events.Log("Rotation required (file size is %{bytes}d seq=%{seq}d)", info.Size(), atomic.LoadInt64(&w.seq)) 75 | doRotate = true 76 | } 77 | } 78 | if w.rotateAfter > 0 && these 
>= w.rotateAfter { 79 | doRotate = true 80 | } 81 | 82 | if doRotate { 83 | events.Debug("Rotating log file..") 84 | these = 0 85 | if err := f.Close(); err != nil { 86 | return errors.Wrap(err, "close during rotation") 87 | } 88 | if err := os.Remove(f.Name()); err != nil { 89 | return errors.Wrap(err, "remove file") 90 | } 91 | f, err = os.Create(w.path) 92 | if err != nil { 93 | return errors.Wrap(err, "rotate into new file") 94 | } 95 | bw = bufio.NewWriter(f) 96 | atomic.AddInt64(&w.rotations, 1) 97 | } 98 | } 99 | return nil 100 | } 101 | -------------------------------------------------------------------------------- /pkg/event/iterator_integration_test.go: -------------------------------------------------------------------------------- 1 | package event 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/require" 11 | 12 | changelogpkg "github.com/segmentio/ctlstore/pkg/changelog" 13 | "github.com/segmentio/ctlstore/pkg/ldb" 14 | "github.com/segmentio/ctlstore/pkg/ldbwriter" 15 | "github.com/segmentio/ctlstore/pkg/logwriter" 16 | "github.com/segmentio/ctlstore/pkg/schema" 17 | "github.com/segmentio/ctlstore/pkg/sqlite" 18 | "github.com/segmentio/ctlstore/pkg/tests" 19 | ) 20 | 21 | func TestIteratorIntegration(t *testing.T) { 22 | ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) 23 | defer cancel() 24 | 25 | f, teardown := tests.WithTmpFile(t, "changelog") 26 | defer teardown() 27 | 28 | changeBuffer := new(sqlite.SQLChangeBuffer) 29 | driverName := fmt.Sprintf("%s_%d", ldb.LDBDatabaseDriver, time.Now().UnixNano()) 30 | 31 | ldbTmpPath, teardown := ldb.NewLDBTmpPath(t) 32 | defer teardown() 33 | 34 | dsn := fmt.Sprintf("file:%s", ldbTmpPath) 35 | 36 | err := sqlite.RegisterSQLiteWatch(driverName, changeBuffer) 37 | require.NoError(t, err) 38 | 39 | db, err := sql.Open(driverName, dsn) 40 | if err != nil { 41 | t.Fatalf("Couldn't open SQLite db, error %v", err) 42 
| } 43 | err = ldb.EnsureLdbInitialized(context.Background(), db) 44 | if err != nil { 45 | t.Fatalf("Couldn't initialize SQLite db, error %v", err) 46 | } 47 | 48 | sqlWriter := &ldbwriter.SqlLdbWriter{Db: db} 49 | sizedLogWriter := &logwriter.SizedLogWriter{Path: f.Name(), FileMode: 0644, RotateSize: 1024 * 1024} 50 | changeLogWriter := &changelogpkg.ChangelogWriter{WriteLine: sizedLogWriter} 51 | writer := &ldbwriter.LDBWriterWithChangelog{ 52 | LdbWriter: sqlWriter, 53 | ChangelogWriter: changeLogWriter, 54 | DB: db, 55 | ChangeBuffer: changeBuffer, 56 | } 57 | 58 | const numChanges = 50 59 | 60 | go func() { 61 | err := writer.ApplyDMLStatement(ctx, schema.NewTestDMLStatement("CREATE TABLE fam___foo (id int primary key not null, val VARCHAR);")) 62 | require.NoError(t, err) 63 | 64 | for i := 0; i < numChanges; i++ { 65 | err = writer.ApplyDMLStatement(ctx, schema.NewTestDMLStatement(fmt.Sprintf("INSERT INTO fam___foo VALUES(%d, 'hello');", i))) 66 | require.NoError(t, err) 67 | } 68 | }() 69 | 70 | iter, err := NewIterator(ctx, f.Name()) 71 | require.NoError(t, err) 72 | require.NotNil(t, iter) 73 | 74 | for i := 0; i < numChanges; i++ { 75 | e, err := iter.Next(ctx) 76 | require.NoError(t, err) 77 | require.EqualValues(t, i+1, e.Sequence) 78 | update := e.RowUpdate 79 | require.Equal(t, "fam", update.FamilyName) 80 | require.Equal(t, "foo", update.TableName) 81 | keys := update.Keys 82 | require.EqualValues(t, []Key{{ 83 | Name: "id", 84 | Type: "INT", 85 | Value: float64(i), 86 | }}, keys) 87 | 88 | } 89 | 90 | } 91 | -------------------------------------------------------------------------------- /pkg/event/iterator_test.go: -------------------------------------------------------------------------------- 1 | package event 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/pkg/errors" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestIterator(t *testing.T) { 13 | ctx, cancel := 
context.WithTimeout(context.Background(), 5*time.Second) 14 | defer cancel() 15 | 16 | const numEvents = 5 17 | 18 | changelog := &fakeChangelog{} 19 | for i := 0; i < numEvents; i++ { 20 | changelog.ers = append(changelog.ers, eventErr{ 21 | event: Event{Sequence: int64(i)}, 22 | }) 23 | } 24 | iter, err := NewIterator(ctx, "test file", func(i *Iterator) { 25 | i.changelog = changelog 26 | }) 27 | require.NoError(t, err) 28 | defer func() { 29 | err := iter.Close() 30 | require.NoError(t, err) 31 | }() 32 | for i := 0; i < numEvents; i++ { 33 | event, err := iter.Next(ctx) 34 | require.NoError(t, err) 35 | require.EqualValues(t, i, event.Sequence) 36 | } 37 | } 38 | 39 | func TestFilteredIterator(t *testing.T) { 40 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 41 | defer cancel() 42 | 43 | const numEvents = 10 44 | 45 | var fam, tbl string 46 | fam = "numbers" 47 | changelog := &fakeChangelog{} 48 | for i := 0; i < numEvents; i++ { 49 | if i%2 == 0 { 50 | tbl = "even" 51 | } else { 52 | tbl = "odd" 53 | } 54 | changelog.ers = append(changelog.ers, eventErr{ 55 | event: Event{ 56 | Sequence: int64(i), 57 | RowUpdate: RowUpdate{ 58 | FamilyName: fam, 59 | TableName: tbl, 60 | Keys: nil, 61 | }, 62 | }, 63 | }) 64 | } 65 | 66 | iter, err := NewFilteredIterator(ctx, "test file", "numbers", "even", func(i *Iterator) { 67 | i.changelog = changelog 68 | }) 69 | require.NoError(t, err) 70 | defer func() { 71 | err := iter.Close() 72 | require.NoError(t, err) 73 | }() 74 | for i := 0; i < numEvents/2; i++ { 75 | event, err := iter.Next(ctx) 76 | require.NoError(t, err) 77 | require.EqualValues(t, i*2, event.Sequence) 78 | } 79 | } 80 | 81 | func TestIteratorFailedChangelogStart(t *testing.T) { 82 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 83 | defer cancel() 84 | 85 | iter, err := NewIterator(ctx, "test file", func(i *Iterator) { 86 | i.changelog = &fakeChangelog{ 87 | startErr: errors.New("failure"), // force a 
failure on startup 88 | ers: []eventErr{ 89 | {event: Event{Sequence: 0}}, 90 | {event: Event{Sequence: 3}}, 91 | {event: Event{Sequence: 4}}, 92 | }, 93 | } 94 | }) 95 | require.Nil(t, iter) 96 | require.EqualError(t, err, "start changelog: failure") 97 | 98 | } 99 | 100 | func TestIteratorSkippedEvent(t *testing.T) { 101 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 102 | defer cancel() 103 | 104 | iter, err := NewIterator(ctx, "test file", func(i *Iterator) { 105 | i.changelog = &fakeChangelog{ 106 | ers: []eventErr{ 107 | {event: Event{Sequence: 0}}, 108 | {event: Event{Sequence: 3}}, 109 | {event: Event{Sequence: 4}}, 110 | }, 111 | } 112 | }) 113 | require.NoError(t, err) 114 | defer func() { 115 | err := iter.Close() 116 | require.NoError(t, err) 117 | }() 118 | event, err := iter.Next(ctx) 119 | require.NoError(t, err) 120 | require.EqualValues(t, 0, event.Sequence) 121 | 122 | event, err = iter.Next(ctx) 123 | require.EqualValues(t, 3, event.Sequence) 124 | require.EqualError(t, err, "out of sync with changelog. 
invalidate caches please.") 125 | 126 | event, err = iter.Next(ctx) 127 | require.NoError(t, err) 128 | require.EqualValues(t, 4, event.Sequence) 129 | } 130 | 131 | func TestFilteredIteratorSkippedEvent(t *testing.T) { 132 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 133 | defer cancel() 134 | 135 | iter, err := NewFilteredIterator(ctx, "test file", "foo", "bar", func(i *Iterator) { 136 | i.changelog = &fakeChangelog{ 137 | ers: []eventErr{ 138 | {event: Event{Sequence: 0}}, 139 | {event: Event{Sequence: 3}}, 140 | {event: Event{Sequence: 4}}, 141 | }, 142 | } 143 | }) 144 | require.NoError(t, err) 145 | defer func() { 146 | err := iter.Close() 147 | require.NoError(t, err) 148 | }() 149 | // even if fam/tbl filter does not match we need to return errors 150 | event, err := iter.Next(ctx) 151 | require.EqualValues(t, 3, event.Sequence) 152 | require.EqualError(t, err, "out of sync with changelog. invalidate caches please.") 153 | } 154 | -------------------------------------------------------------------------------- /pkg/executive/db_info.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | 8 | "github.com/go-sql-driver/mysql" 9 | mysql2 "github.com/segmentio/ctlstore/pkg/mysql" 10 | "github.com/segmentio/ctlstore/pkg/schema" 11 | sqlite2 "github.com/segmentio/ctlstore/pkg/sqlite" 12 | sqlite "github.com/segmentio/go-sqlite3" 13 | ) 14 | 15 | type sqlDBInfo interface { 16 | GetColumnInfo(ctx context.Context, tableNames []string) ([]schema.DBColumnInfo, error) 17 | GetAllTables(ctx context.Context) ([]schema.FamilyTable, error) 18 | } 19 | 20 | func getDBInfo(db *sql.DB) sqlDBInfo { 21 | switch t := db.Driver().(type) { 22 | case *mysql.MySQLDriver: 23 | return &mysql2.MySQLDBInfo{Db: db} 24 | case *sqlite.SQLiteDriver: 25 | return &sqlite2.SqliteDBInfo{Db: db} 26 | default: 27 | panic(fmt.Sprintf("Invalid driver type 
%T", t)) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /pkg/executive/dml_ledger_writer.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "github.com/segmentio/ctlstore/pkg/errs" 8 | "github.com/segmentio/ctlstore/pkg/schema" 9 | "github.com/segmentio/ctlstore/pkg/sqlgen" 10 | "github.com/segmentio/stats/v4" 11 | ) 12 | 13 | // Writes DML entries to log table within an existing transaction. Make 14 | // sure to call Close() after finishing. 15 | type dmlLedgerWriter struct { 16 | Tx *sql.Tx 17 | TableName string 18 | _stmt *sql.Stmt 19 | } 20 | 21 | func (w *dmlLedgerWriter) BeginTx(ctx context.Context) (seq schema.DMLSequence, err error) { 22 | return w.Add(ctx, schema.DMLTxBeginKey) 23 | } 24 | 25 | func (w *dmlLedgerWriter) CommitTx(ctx context.Context) (seq schema.DMLSequence, err error) { 26 | return w.Add(ctx, schema.DMLTxEndKey) 27 | } 28 | 29 | // Writes an entry to the DML log, returning the sequence or an error 30 | // if any occurs. 
31 | func (w *dmlLedgerWriter) Add(ctx context.Context, statement string) (seq schema.DMLSequence, err error) { 32 | if w._stmt == nil { 33 | qs := sqlgen.SqlSprintf("INSERT INTO $1 (statement) VALUES(?)", w.TableName) 34 | stmt, err := w.Tx.PrepareContext(ctx, qs) 35 | if err != nil { 36 | errs.Incr("dml_ledger_writer.prepare.error") 37 | return 0, err 38 | } 39 | w._stmt = stmt 40 | } 41 | 42 | res, err := w._stmt.ExecContext(ctx, statement) 43 | if err != nil { 44 | errs.Incr("dml_ledger_writer.exec.error") 45 | return 46 | } 47 | stats.Incr("dml_ledger_writer.exec.success") 48 | 49 | seqId, err := res.LastInsertId() 50 | if err != nil { 51 | errs.Incr("dml_ledger_writer.last_insert_id_error") 52 | return 53 | } 54 | 55 | seq = schema.DMLSequence(seqId) 56 | return 57 | } 58 | 59 | func (w *dmlLedgerWriter) Close() error { 60 | if w._stmt != nil { 61 | return w._stmt.Close() 62 | } 63 | return nil 64 | } 65 | -------------------------------------------------------------------------------- /pkg/executive/dml_ledger_writer_test.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/segmentio/ctlstore/pkg/schema" 9 | ) 10 | 11 | func TestDMLLogWriterAdd(t *testing.T) { 12 | suite := []struct { 13 | desc string 14 | statements []string 15 | }{ 16 | { 17 | desc: "Insert the first statement", 18 | statements: []string{ 19 | "INSERT INTO foo_bar VALUES('x', 'y', 123)", 20 | }, 21 | }, 22 | } 23 | 24 | for i, testCase := range suite { 25 | testName := fmt.Sprintf("[%d] %s", i, testCase.desc) 26 | t.Run(testName, func(t *testing.T) { 27 | t.Parallel() 28 | ctx := context.Background() 29 | db, teardown := newCtlDBTestConnection(t, "mysql") 30 | defer teardown() 31 | 32 | tx, err := db.BeginTx(ctx, nil) 33 | if err != nil { 34 | t.Fatalf("Unexpected error: %v", err) 35 | } 36 | defer tx.Rollback() 37 | 38 | w := &dmlLedgerWriter{ 39 | Tx: tx, 40 | 
TableName: "ctlstore_dml_ledger", 41 | } 42 | defer w.Close() 43 | 44 | seqs := []schema.DMLSequence{} 45 | for _, stString := range testCase.statements { 46 | seq, err := w.Add(ctx, stString) 47 | if err != nil { 48 | t.Fatalf("Unexpected error: %v", err) 49 | } 50 | seqs = append(seqs, seq) 51 | } 52 | 53 | row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM ctlstore_dml_ledger") 54 | if err != nil { 55 | t.Fatalf("Unexpected error: %v", err) 56 | } 57 | 58 | var cnt int 59 | err = row.Scan(&cnt) 60 | if err != nil { 61 | t.Fatalf("Unexpected error: %v", err) 62 | } 63 | 64 | if want, got := 0, cnt; want != got { 65 | t.Errorf("Expected row count to start at %d, got %d", want, got) 66 | } 67 | 68 | err = tx.Commit() 69 | if err != nil { 70 | t.Fatalf("Unexpected error: %v", err) 71 | } 72 | 73 | rows, err := db.QueryContext(ctx, 74 | "SELECT seq, statement "+ 75 | "FROM ctlstore_dml_ledger "+ 76 | "ORDER BY seq ASC") 77 | if err != nil { 78 | t.Fatalf("Unexpected error: %v", err) 79 | } 80 | defer rows.Close() 81 | 82 | i := 0 83 | for rows.Next() { 84 | if i+1 > len(testCase.statements) { 85 | t.Errorf("Scanned more statements than expected") 86 | break 87 | } 88 | 89 | var rowSeq int64 90 | var rowStmt string 91 | 92 | err = rows.Scan(&rowSeq, &rowStmt) 93 | if err != nil { 94 | t.Fatalf("Unexpected error: %v", err) 95 | } 96 | 97 | if want, got := seqs[i], schema.DMLSequence(rowSeq); want != got { 98 | t.Errorf("Expected %v, got %v", want, got) 99 | } 100 | 101 | if want, got := testCase.statements[i], rowStmt; want != got { 102 | t.Errorf("Expected %v, got %v", want, got) 103 | } 104 | 105 | i++ 106 | } 107 | 108 | }) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /pkg/executive/executive.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import ( 4 | "github.com/pkg/errors" 5 | "github.com/segmentio/ctlstore/pkg/limits" 6 | 
"github.com/segmentio/ctlstore/pkg/schema" 7 | ) 8 | 9 | const ( 10 | // DefaultExecutiveURL is where the executive can be reached 11 | DefaultExecutiveURL = "ctlstore-executive.segment.local" 12 | ) 13 | 14 | type ExecutiveMutationRequest struct { 15 | TableName string 16 | Delete bool 17 | Values map[string]interface{} 18 | } 19 | 20 | //counterfeiter:generate -o fakes/executive_interface.go . ExecutiveInterface 21 | type ExecutiveInterface interface { 22 | CreateFamily(familyName string) error 23 | CreateTable(familyName string, tableName string, fieldNames []string, fieldTypes []schema.FieldType, keyFields []string) error 24 | CreateTables([]schema.Table) error 25 | AddFields(familyName string, tableName string, fieldNames []string, fieldTypes []schema.FieldType) error 26 | 27 | Mutate(writerName string, writerSecret string, familyName string, cookie []byte, checkCookie []byte, requests []ExecutiveMutationRequest) error 28 | GetWriterCookie(writerName string, writerSecret string) ([]byte, error) 29 | SetWriterCookie(writerName string, writerSecret string, cookie []byte) error 30 | RegisterWriter(writerName string, writerSecret string) error 31 | 32 | TableSchema(familyName string, tableName string) (*schema.Table, error) 33 | FamilySchemas(familyName string) ([]schema.Table, error) 34 | 35 | ReadRow(familyName string, tableName string, where map[string]interface{}) (map[string]interface{}, error) 36 | 37 | ReadTableSizeLimits() (limits.TableSizeLimits, error) 38 | UpdateTableSizeLimit(limit limits.TableSizeLimit) error 39 | DeleteTableSizeLimit(table schema.FamilyTable) error 40 | 41 | ReadWriterRateLimits() (limits.WriterRateLimits, error) 42 | UpdateWriterRateLimit(limit limits.WriterRateLimit) error 43 | DeleteWriterRateLimit(writerName string) error 44 | 45 | ClearTable(table schema.FamilyTable) error 46 | DropTable(table schema.FamilyTable) error 47 | ReadFamilyTableNames(familyName schema.FamilyName) ([]schema.FamilyTable, error) 48 | } 49 | 50 | type 
mutationRequest struct { 51 | FamilyName schema.FamilyName 52 | TableName schema.TableName 53 | Delete bool 54 | Values map[schema.FieldName]interface{} 55 | } 56 | 57 | /* newMutationRequest validates and converts a raw ExecutiveMutationRequest into a typed mutationRequest for the given family. */ func newMutationRequest(famName schema.FamilyName, req ExecutiveMutationRequest) (mutationRequest, error) { 58 | tblName, err := schema.NewTableName(req.TableName) 59 | if err != nil { 60 | /* bug fix: previously returned nil here, silently swallowing the validation error and yielding a zero-valued request */ return mutationRequest{}, err 61 | } 62 | 63 | vals := map[schema.FieldName]interface{}{} 64 | for name, val := range req.Values { 65 | fn, err := schema.NewFieldName(name) 66 | if err != nil { 67 | return mutationRequest{}, err 68 | } 69 | vals[fn] = val 70 | } 71 | 72 | return mutationRequest{ 73 | FamilyName: famName, 74 | TableName: tblName, 75 | Delete: req.Delete, 76 | Values: vals, 77 | }, nil 78 | } 79 | 80 | // Returns the request Values as a slice in the order specified by the 81 | // fieldOrder param. An error will be returned if a field is missing. 82 | func (r *mutationRequest) valuesByOrder(fieldOrder []schema.FieldName) ([]interface{}, error) { 83 | values := []interface{}{} 84 | for _, fn := range fieldOrder { 85 | if v, ok := r.Values[fn]; ok { 86 | values = append(values, v) 87 | } else { 88 | return nil, errors.Errorf("Missing field %s", fn) 89 | } 90 | } 91 | return values, nil 92 | } 93 | 94 | type mutationRequestSet struct { 95 | Requests []mutationRequest 96 | } 97 | 98 | func newMutationRequestSet(famName schema.FamilyName, exReqs []ExecutiveMutationRequest) (mutationRequestSet, error) { 99 | reqs := make([]mutationRequest, len(exReqs)) 100 | for i, exReq := range exReqs { 101 | req, err := newMutationRequest(famName, exReq) 102 | if err != nil { 103 | return mutationRequestSet{}, err 104 | } 105 | reqs[i] = req 106 | } 107 | return mutationRequestSet{reqs}, nil 108 | } 109 | 110 | // Return the unique set of table names as a O(1) lookup map 111 | func (s *mutationRequestSet) TableNameSet() map[schema.TableName]struct{} { 112 | tnset := map[schema.TableName]struct{}{} 113 | for _, req :=
range s.Requests { 114 | tnset[req.TableName] = struct{}{} 115 | } 116 | return tnset 117 | } 118 | 119 | // Return the unique set of table names as a slice 120 | func (s *mutationRequestSet) TableNames() []schema.TableName { 121 | tns := []schema.TableName{} 122 | for tableName := range s.TableNameSet() { 123 | tns = append(tns, tableName) 124 | } 125 | return tns 126 | } 127 | -------------------------------------------------------------------------------- /pkg/executive/fake_time.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import ( 4 | "sync/atomic" 5 | "time" 6 | ) 7 | 8 | // fakeTime is a type that produces time.Times to second precision 9 | type fakeTime struct { 10 | epoch int64 11 | } 12 | 13 | func newFakeTime(epoch int64) *fakeTime { 14 | ft := &fakeTime{} 15 | ft.set(epoch) 16 | return ft 17 | } 18 | 19 | func (t *fakeTime) set(epoch int64) { 20 | atomic.StoreInt64(&t.epoch, epoch) 21 | } 22 | 23 | func (t *fakeTime) get() time.Time { 24 | val := atomic.LoadInt64(&t.epoch) 25 | return time.Unix(val, 0) 26 | } 27 | 28 | func (t *fakeTime) add(delta int64) { 29 | atomic.AddInt64(&t.epoch, delta) 30 | } 31 | -------------------------------------------------------------------------------- /pkg/executive/generate.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate 4 | -------------------------------------------------------------------------------- /pkg/executive/health.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | type HealthChecker interface { 4 | HealthCheck() error 5 | } 6 | -------------------------------------------------------------------------------- /pkg/executive/mutators_store_test.go: -------------------------------------------------------------------------------- 1 | package 
executive 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "testing" 7 | 8 | "github.com/segmentio/ctlstore/pkg/schema" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestMutatorStoreExists(t *testing.T) { 13 | ctx := context.Background() 14 | db, teardown := newCtlDBTestConnection(t, "mysql") 15 | defer teardown() 16 | tx, err := db.BeginTx(ctx, nil) 17 | require.NoError(t, err) 18 | defer tx.Rollback() 19 | ms := mutatorStore{DB: tx, Ctx: ctx, TableName: "mutators"} 20 | checkExists := func(name string, expected bool) { 21 | exists, err := ms.Exists(schema.WriterName{Name: name}) 22 | require.NoError(t, err) 23 | require.Equal(t, expected, exists) 24 | } 25 | checkExists("my-writer", false) 26 | require.NoError(t, ms.Register(schema.WriterName{Name: "my-writer"}, "my-secret")) 27 | checkExists("my-writer", true) 28 | } 29 | 30 | func TestMutatorStoreRegisterAndGet(t *testing.T) { 31 | ctx := context.Background() 32 | db, teardown := newCtlDBTestConnection(t, "mysql") 33 | defer teardown() 34 | 35 | tx1, err := db.BeginTx(ctx, nil) 36 | require.NoError(t, err) 37 | 38 | defer tx1.Rollback() 39 | 40 | ms1 := mutatorStore{ 41 | DB: tx1, 42 | Ctx: ctx, 43 | TableName: "mutators", 44 | } 45 | 46 | // the hash of "" is the secret that the DB that the writer is seeded with. 
so 47 | // this register should return successfully since it is a no-op 48 | err = ms1.Register(schema.WriterName{Name: "writer1"}, "") 49 | require.NoError(t, err) 50 | 51 | err = ms1.Register(schema.WriterName{Name: "writer1"}, "different-password") 52 | require.Equal(t, ErrWriterAlreadyExists, err) 53 | 54 | tx2, err := db.BeginTx(ctx, nil) 55 | require.NoError(t, err) 56 | defer tx2.Rollback() 57 | 58 | ms2 := mutatorStore{ 59 | DB: tx2, 60 | Ctx: ctx, 61 | TableName: "mutators", 62 | } 63 | 64 | _, ok, err := ms2.Get(schema.WriterName{Name: "writerNotFound"}, "") 65 | require.False(t, ok) 66 | } 67 | 68 | func TestMutatorStoreUpdate(t *testing.T) { 69 | suite := []struct { 70 | desc string 71 | writerName string 72 | cookie []byte 73 | ifCookie []byte 74 | expectErr error 75 | }{ 76 | { 77 | desc: "Overwrite existing cookie", 78 | writerName: "writer1", 79 | cookie: []byte{0}, 80 | }, 81 | { 82 | desc: "Check-and-set existing cookie success", 83 | writerName: "writer1", 84 | cookie: []byte{2}, 85 | ifCookie: []byte{1}, 86 | }, 87 | { 88 | desc: "Check-and-set existing cookie conflict", 89 | writerName: "writer1", 90 | cookie: []byte{2}, 91 | ifCookie: []byte{0}, 92 | expectErr: ErrCookieConflict, 93 | }, 94 | { 95 | desc: "Set to same succeeds", 96 | writerName: "writer1", 97 | cookie: []byte{1}, 98 | ifCookie: []byte{1}, 99 | }, 100 | { 101 | desc: "Send super long cookie", 102 | writerName: "writer1", 103 | cookie: bytes.Repeat([]byte{0}, 1025), 104 | expectErr: ErrCookieTooLong, 105 | }, 106 | { 107 | desc: "Send super long if cookie", 108 | writerName: "writer1", 109 | cookie: []byte{1}, 110 | ifCookie: bytes.Repeat([]byte{0}, 1025), 111 | expectErr: ErrCookieTooLong, 112 | }, 113 | { 114 | desc: "Non-existant writer", 115 | writerName: "writer100", 116 | expectErr: ErrWriterNotFound, 117 | }, 118 | } 119 | 120 | for _, testCase := range suite { 121 | t.Run(testCase.desc, func(t *testing.T) { 122 | ctx := context.Background() 123 | db, teardown := 
newCtlDBTestConnection(t, "mysql") 124 | defer teardown() 125 | 126 | tx, err := db.BeginTx(ctx, nil) 127 | if err != nil { 128 | t.Fatalf("Unexpected error: %v", err) 129 | } 130 | defer tx.Rollback() 131 | 132 | ms := mutatorStore{ 133 | DB: tx, 134 | Ctx: ctx, 135 | TableName: "mutators", 136 | } 137 | 138 | err = ms.Update( 139 | schema.WriterName{Name: testCase.writerName}, 140 | "", 141 | testCase.cookie, 142 | testCase.ifCookie) 143 | 144 | if want, got := testCase.expectErr, err; want != got { 145 | t.Errorf("Expected error %v, got %v", want, got) 146 | } 147 | }) 148 | } 149 | 150 | } 151 | -------------------------------------------------------------------------------- /pkg/executive/sql.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "strings" 7 | 8 | _ "github.com/segmentio/go-sqlite3" // gives us sqlite3 everywhere 9 | ) 10 | 11 | // SQLDBClient allows generalizing several database/sql types 12 | type SQLDBClient interface { 13 | ExecContext(context.Context, string, ...interface{}) (sql.Result, error) 14 | PrepareContext(context.Context, string) (*sql.Stmt, error) 15 | QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) 16 | QueryRowContext(context.Context, string, ...interface{}) *sql.Row 17 | } 18 | 19 | func errorIsRowConflict(err error) bool { 20 | return strings.Contains(err.Error(), "Duplicate entry") || 21 | strings.Contains(err.Error(), "UNIQUE constraint failed") 22 | } 23 | -------------------------------------------------------------------------------- /pkg/executive/sql_test.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import ( 4 | "net/url" 5 | "testing" 6 | 7 | "github.com/segmentio/ctlstore/pkg/ctldb" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestAddParameterToDSN(t *testing.T) { 12 | dsn := "http://foo.bar?x=y" 13 | 
updated, err := ctldb.AddParameterToDSN(dsn, "update", "newvalue") 14 | assert.NoError(t, err) 15 | parsed, err := url.Parse(updated) 16 | assert.NoError(t, err) 17 | if val := parsed.Query().Get("x"); val != "y" { 18 | t.Fatalf("old value should be 'y' but was '%s'", val) 19 | } 20 | if val := parsed.Query().Get("update"); val != "newvalue" { 21 | t.Fatalf("new value should be 'newvalue' but was '%s'", val) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /pkg/executive/status_writer.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import "net/http" 4 | 5 | type statusWriter struct { 6 | writer http.ResponseWriter 7 | code int 8 | } 9 | 10 | func (w *statusWriter) Header() http.Header { 11 | return w.writer.Header() 12 | } 13 | 14 | func (w *statusWriter) Write(b []byte) (int, error) { 15 | return w.writer.Write(b) 16 | } 17 | 18 | func (w *statusWriter) WriteHeader(statusCode int) { 19 | w.code = statusCode 20 | w.writer.WriteHeader(statusCode) 21 | } 22 | -------------------------------------------------------------------------------- /pkg/executive/test_executive.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "io/ioutil" 7 | "net" 8 | "net/http" 9 | "os" 10 | "os/signal" 11 | "path/filepath" 12 | "sync" 13 | "syscall" 14 | "time" 15 | 16 | "github.com/segmentio/ctlstore/pkg/ctldb" 17 | "github.com/segmentio/ctlstore/pkg/limits" 18 | "github.com/segmentio/ctlstore/pkg/sqlgen" 19 | "github.com/segmentio/ctlstore/pkg/units" 20 | "github.com/segmentio/events/v2" 21 | ) 22 | 23 | type TestExecutiveService struct { 24 | Addr net.Addr 25 | ctldb *sql.DB 26 | tmpDir string 27 | h *http.Server 28 | } 29 | 30 | func NewTestExecutiveService(bindTo string) (*TestExecutiveService, error) { 31 | tmpDir, err := ioutil.TempDir("", "") 32 | if err != 
nil { 33 | return nil, err 34 | } 35 | 36 | db, err := sql.Open("sqlite3", filepath.Join(tmpDir, "ctldb.db")) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | err = ctldb.InitializeCtlDB(db, sqlgen.SqlDriverToDriverName) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | svc := &TestExecutiveService{ 47 | tmpDir: tmpDir, 48 | ctldb: db, 49 | } 50 | 51 | stop := make(chan os.Signal, 1) 52 | signal.Notify(stop, os.Interrupt, syscall.SIGTERM) 53 | svc.h = &http.Server{Handler: svc} 54 | 55 | started := sync.WaitGroup{} 56 | started.Add(2) 57 | 58 | go func() { 59 | listener, err := net.Listen("tcp", bindTo) 60 | if err != nil { 61 | events.Log("Error listening: %{error}+v", err) 62 | started.Done() 63 | return 64 | } 65 | 66 | // Allows for getting the port after random port assignment 67 | svc.Addr = listener.Addr() 68 | started.Done() 69 | 70 | if err := svc.h.Serve(listener); err != nil && err != http.ErrServerClosed { 71 | events.Log("Error serving: %{error}+v", err) 72 | } 73 | }() 74 | 75 | go func() { 76 | started.Done() 77 | 78 | <-stop 79 | svc.shutdown() 80 | }() 81 | 82 | started.Wait() 83 | 84 | return svc, nil 85 | } 86 | 87 | func (s *TestExecutiveService) shutdown() { 88 | sctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 89 | defer cancel() 90 | 91 | if err := s.h.Shutdown(sctx); err != nil { 92 | events.Log("Shutdown error: %{error}+v", err) 93 | } 94 | } 95 | 96 | func (s *TestExecutiveService) ServeHTTP(w http.ResponseWriter, r *http.Request) { 97 | ctx := context.Background() 98 | 99 | // Setup and tear these down every req to limit thread-safety garbage 100 | cR := r.WithContext(ctx) 101 | limiter := newDBLimiter( 102 | s.ctldb, 103 | "sqlite3", limits.SizeLimits{ 104 | MaxSize: 100 * units.MEGABYTE, 105 | WarnSize: 50 * units.MEGABYTE, 106 | }, 107 | time.Second, 108 | 1000, 109 | ) 110 | exec := &dbExecutive{DB: s.ctldb, Ctx: ctx, limiter: limiter} 111 | ep := ExecutiveEndpoint{Exec: exec, HealthChecker: 
exec} 112 | defer ep.Close() 113 | ep.Handler().ServeHTTP(w, cR) 114 | } 115 | 116 | func (s *TestExecutiveService) Close() error { 117 | s.shutdown() 118 | s.ctldb.Close() 119 | os.RemoveAll(s.tmpDir) 120 | return nil 121 | } 122 | 123 | func (s *TestExecutiveService) ExecutiveInterface() ExecutiveInterface { 124 | return &dbExecutive{DB: s.ctldb, Ctx: context.Background()} 125 | } 126 | -------------------------------------------------------------------------------- /pkg/executive/test_executive_test.go: -------------------------------------------------------------------------------- 1 | package executive 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestTestExecutiveService(t *testing.T) { 10 | svc, err := NewTestExecutiveService("127.0.0.1:0") 11 | if err != nil { 12 | t.Fatalf("Unexpected error: %v", err) 13 | } 14 | 15 | { 16 | resp, err := http.Get("http://" + svc.Addr.String() + "/status") 17 | if err != nil { 18 | t.Fatalf("Unexpected error: %v", err) 19 | } 20 | 21 | if want, got := 200, resp.StatusCode; want != got { 22 | t.Errorf("Expected status code: %v, got %v", want, got) 23 | } 24 | } 25 | 26 | { 27 | ior := strings.NewReader("") 28 | resp, err := http.Post("http://"+svc.Addr.String()+"/families/test1", "text/plain", ior) 29 | if err != nil { 30 | t.Fatalf("Unexpected error: %v", err) 31 | } 32 | 33 | if want, got := 200, resp.StatusCode; want != got { 34 | t.Errorf("Expected status code: %v, got %v", want, got) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /pkg/globalstats/stats_test.go: -------------------------------------------------------------------------------- 1 | package globalstats 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sort" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/segmentio/stats/v4" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | type fakeHandler struct { 16 | mut sync.Mutex 17 | measuresByName 
map[string][]stats.Measure 18 | } 19 | 20 | // fakeHandler needs to conform to the stats.Handler interface. 21 | var _ stats.Handler = &fakeHandler{} 22 | 23 | func newFakeHandler() *fakeHandler { 24 | return &fakeHandler{ 25 | measuresByName: make(map[string][]stats.Measure), 26 | } 27 | } 28 | 29 | func (h *fakeHandler) HandleMeasures(t time.Time, measures ...stats.Measure) { 30 | h.mut.Lock() 31 | defer h.mut.Unlock() 32 | 33 | for _, m := range measures { 34 | if _, ok := h.measuresByName[m.Name]; !ok { 35 | h.measuresByName[m.Name] = []stats.Measure{} 36 | } 37 | h.measuresByName[m.Name] = append(h.measuresByName[m.Name], m.Clone()) 38 | } 39 | } 40 | 41 | func TestGlobalStats(t *testing.T) { 42 | t.Skip("This test is flapping. Skipping until it can be remediated.") 43 | ctx, cancel := context.WithCancel(context.Background()) 44 | 45 | // Overwrite the default engine with a testing mock. 46 | h := newFakeHandler() 47 | originalHandler := stats.DefaultEngine.Handler 48 | stats.DefaultEngine.Handler = h 49 | defer func() { 50 | // Replace the default DefaultEngine.Handler 51 | stats.DefaultEngine.Handler = originalHandler 52 | }() 53 | 54 | // Initialize globalstats. 55 | Initialize(ctx, Config{ 56 | FlushEvery: 10 * time.Millisecond, 57 | }) 58 | 59 | // Perform some example Incr operations. 60 | Incr("a", "family-a", "table-a") 61 | Incr("b", "family-a", "table-a") 62 | Incr("a", "family-a", "table-a") 63 | Incr("a", "family-a", "table-b") 64 | 65 | // Wait for the Incr operations to propogate to the flusher. 66 | time.Sleep(15 * time.Millisecond) 67 | cancel() 68 | 69 | h.mut.Lock() 70 | defer h.mut.Unlock() 71 | 72 | // Verify that the three Incr operations were flushed. 73 | flusherMeasures, ok := h.measuresByName["ctlstore.global"] 74 | require.True(t, ok) 75 | // Sort the measures we received so that we can reliably compare the output. 
76 | sort.Slice(flusherMeasures, func(i, j int) bool { 77 | fi, fj := flusherMeasures[i], flusherMeasures[j] 78 | if len(fi.Fields) != len(fj.Fields) || len(fi.Fields) == 0 { 79 | return len(fi.Fields) < len(fj.Fields) 80 | } 81 | if fi.Fields[0].Value.Int() != fj.Fields[0].Value.Int() { 82 | return fi.Fields[0].Value.Int() < fj.Fields[0].Value.Int() 83 | } 84 | return fi.Fields[0].Name < fj.Fields[0].Name 85 | }) 86 | fmt.Printf("%+v\n", flusherMeasures) 87 | require.Equal(t, []stats.Measure{ 88 | { 89 | Name: "ctlstore.global", 90 | Fields: []stats.Field{ 91 | stats.MakeField("dropped-stats", 0, stats.Counter), 92 | }, 93 | Tags: []stats.Tag{ 94 | stats.T("app", "globalstats.test"), 95 | stats.T("version", "unknown"), 96 | }, 97 | }, 98 | { 99 | Name: "ctlstore.global", 100 | Fields: []stats.Field{ 101 | stats.MakeField("a", 1, stats.Counter), 102 | }, 103 | Tags: []stats.Tag{ 104 | stats.T("app", "globalstats.test"), 105 | stats.T("family", "family-a"), 106 | stats.T("table", "table-b"), 107 | stats.T("version", "unknown"), 108 | }, 109 | }, 110 | { 111 | Name: "ctlstore.global", 112 | Fields: []stats.Field{ 113 | stats.MakeField("b", 1, stats.Counter), 114 | }, 115 | Tags: []stats.Tag{ 116 | stats.T("app", "globalstats.test"), 117 | stats.T("family", "family-a"), 118 | stats.T("table", "table-a"), 119 | stats.T("version", "unknown"), 120 | }, 121 | }, 122 | { 123 | Name: "ctlstore.global", 124 | Fields: []stats.Field{ 125 | stats.MakeField("a", 2, stats.Counter), 126 | }, 127 | Tags: []stats.Tag{ 128 | stats.T("app", "globalstats.test"), 129 | stats.T("family", "family-a"), 130 | stats.T("table", "table-a"), 131 | stats.T("version", "unknown"), 132 | }, 133 | }, 134 | }, flusherMeasures) 135 | } 136 | -------------------------------------------------------------------------------- /pkg/heartbeat/heartbeat_test.go: -------------------------------------------------------------------------------- 1 | package heartbeat 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 
"net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | 11 | _ "github.com/go-sql-driver/mysql" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestHeartbeat(t *testing.T) { 16 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 17 | defer cancel() 18 | rCh := make(chan *http.Request) 19 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 20 | fmt.Println("Received", r.URL.Path) 21 | switch r.URL.Path { 22 | case "/writers/writer-name": 23 | case "/families/my-family": 24 | case "/families/my-family/tables/my-table": 25 | case "/families/my-family/mutations": 26 | default: 27 | http.Error(w, fmt.Sprintf("invalid path: %s", r.URL.Path), http.StatusInternalServerError) 28 | t.Fatal("unexpected:", r.URL) 29 | } 30 | select { 31 | case rCh <- r: 32 | case <-ctx.Done(): 33 | t.Fatal(ctx.Err()) 34 | } 35 | })) 36 | defer server.Close() 37 | go func() { 38 | h, err := HeartbeatFromConfig(HeartbeatConfig{ 39 | Table: "my-table", 40 | Family: "my-family", 41 | ExecutiveURL: server.URL, 42 | WriterName: "writer-name", 43 | WriterSecret: "writer-secret", 44 | HeartbeatInterval: 10 * time.Hour, 45 | }) 46 | require.NoError(t, err) 47 | require.NotNil(t, h) 48 | defer h.Close() 49 | 50 | h.Start(ctx) 51 | }() 52 | 53 | nextRequest := func() *http.Request { 54 | select { 55 | case r := <-rCh: 56 | return r 57 | case <-ctx.Done(): 58 | t.Fatal(ctx.Err()) 59 | } 60 | panic("unreachable") 61 | } 62 | r := nextRequest() 63 | require.Equal(t, http.MethodPost, r.Method) 64 | require.Equal(t, "/writers/writer-name", r.URL.Path) 65 | 66 | r = nextRequest() 67 | require.Equal(t, http.MethodPost, r.Method) 68 | require.Equal(t, "/families/my-family", r.URL.Path) 69 | 70 | r = nextRequest() 71 | require.Equal(t, http.MethodPost, r.Method) 72 | require.Equal(t, "/families/my-family/tables/my-table", r.URL.Path) 73 | 74 | // verify at least one mutation was sent 75 | 76 | r = nextRequest() 77 | 
require.Equal(t, http.MethodPost, r.Method) 78 | require.Equal(t, "/families/my-family/mutations", r.URL.Path) 79 | } 80 | -------------------------------------------------------------------------------- /pkg/ldb/ldbs.go: -------------------------------------------------------------------------------- 1 | package ldb 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "io/ioutil" 8 | "os" 9 | "path/filepath" 10 | "sync/atomic" 11 | "testing" 12 | 13 | "github.com/segmentio/ctlstore/pkg/schema" 14 | ) 15 | 16 | const ( 17 | LDBSeqTableName = "_ldb_seq" 18 | LDBLastUpdateTableName = "_ldb_last_update" 19 | LDBLastLedgerUpdateColumn = "ledger" 20 | LDBSeqTableID = 1 21 | LDBDatabaseDriver = "sqlite3" 22 | DefaultLDBFilename = "ldb.db" 23 | ) 24 | 25 | var ( 26 | // SQL for fetching current tracked sequence 27 | ldbFetchSeqSQL = fmt.Sprintf(` 28 | SELECT seq FROM %s WHERE id = %d 29 | `, LDBSeqTableName, LDBSeqTableID) 30 | 31 | ldbInitializeDDLs = []string{ 32 | // Initialization DDL for table that tracks sequence position. Tried to avoid 33 | // a PK column but it makes updating the sequence monotonically messy. 34 | fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( 35 | id INTEGER PRIMARY KEY NOT NULL, 36 | seq BIGINT NOT NULL 37 | )`, LDBSeqTableName), 38 | fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( 39 | name STRING PRIMARY KEY NOT NULL, 40 | timestamp DATETIME NOT NULL 41 | )`, LDBLastUpdateTableName), 42 | } 43 | ) 44 | 45 | var testTmpSeq int64 = 0 46 | 47 | func LDBForTestWithPath(t testing.TB) (res *sql.DB, teardown func(), path string) { 48 | tmpDir, err := ioutil.TempDir("", "ldb-for-test") 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | // Since there's a need for multiple TXs, have to use a tmp file 54 | // for the database. In-memory in shared-cache mode kinda works, 55 | // but it has aggressive locking that blocks the tests we want to do. 
56 | path = NextTestLdbTmpPath(tmpDir) 57 | db, err := OpenLDB(path, "rwc") 58 | if err != nil { 59 | t.Fatalf("Couldn't open SQLite db, error %v", err) 60 | } 61 | err = EnsureLdbInitialized(context.Background(), db) 62 | if err != nil { 63 | t.Fatalf("Couldn't initialize SQLite db, error %v", err) 64 | } 65 | return db, func() { 66 | if tmpDir != "" { 67 | os.RemoveAll(tmpDir) 68 | } 69 | }, path 70 | } 71 | 72 | func LDBForTest(t testing.TB) (res *sql.DB, teardown func()) { 73 | x, y, _ := LDBForTestWithPath(t) 74 | return x, y 75 | } 76 | 77 | func OpenLDB(path string, mode string) (*sql.DB, error) { 78 | return sql.Open("sqlite3_with_autocheckpoint_off", 79 | fmt.Sprintf("file:%s?_journal_mode=wal&mode=%s", path, mode)) 80 | } 81 | 82 | func OpenImmutableLDB(path string) (*sql.DB, error) { 83 | return sql.Open("sqlite3_with_autocheckpoint_off", fmt.Sprintf("file:%s?immutable=true", path)) 84 | } 85 | 86 | // Ensures the LDB is prepared for queries 87 | func EnsureLdbInitialized(ctx context.Context, db *sql.DB) error { 88 | for _, statement := range ldbInitializeDDLs { 89 | if _, err := db.ExecContext(ctx, statement); err != nil { 90 | return err 91 | } 92 | } 93 | return nil 94 | } 95 | 96 | func NewLDBTmpPath(t *testing.T) (string, func()) { 97 | path, err := ioutil.TempDir("", "ldb-tmp-path") 98 | if err != nil { 99 | t.Fatal(err) 100 | } 101 | dbPath := filepath.Join(path, "ldb.db") 102 | return dbPath, func() { 103 | if path != "" { 104 | os.RemoveAll(path) 105 | } 106 | } 107 | } 108 | 109 | func NextTestLdbTmpPath(testTmpDir string) string { 110 | nextSeq := atomic.AddInt64(&testTmpSeq, 1) 111 | return fmt.Sprintf("%s/ldbForTest%d.db", testTmpDir, nextSeq) 112 | } 113 | 114 | // Gets current sequence from provided db 115 | func FetchSeqFromLdb(ctx context.Context, db *sql.DB) (schema.DMLSequence, error) { 116 | row := db.QueryRowContext(ctx, ldbFetchSeqSQL) 117 | var seq int64 118 | err := row.Scan(&seq) 119 | if err == sql.ErrNoRows { 120 | return 
schema.DMLSequence(0), nil 121 | } 122 | return schema.DMLSequence(seq), err 123 | } 124 | -------------------------------------------------------------------------------- /pkg/ldbwriter/changelog_callback.go: -------------------------------------------------------------------------------- 1 | package ldbwriter 2 | 3 | import ( 4 | "context" 5 | "sync/atomic" 6 | 7 | "github.com/segmentio/ctlstore/pkg/changelog" 8 | "github.com/segmentio/ctlstore/pkg/schema" 9 | "github.com/segmentio/events/v2" 10 | ) 11 | 12 | type ChangelogCallback struct { 13 | ChangelogWriter *changelog.ChangelogWriter 14 | Seq int64 15 | } 16 | 17 | func (c *ChangelogCallback) LDBWritten(ctx context.Context, data LDBWriteMetadata) { 18 | for _, change := range data.Changes { 19 | fam, tbl, err := schema.DecodeLDBTableName(change.TableName) 20 | if err != nil { 21 | // This is expected because it'll capture tables like ctlstore_dml_ledger, 22 | // which aren't tables this cares about. 23 | events.Debug("Skipped logging change to %{tableName}s, can't decode table: %{error}v", 24 | change.TableName, 25 | err) 26 | continue 27 | } 28 | 29 | keys, err := change.ExtractKeys(data.DB) 30 | if err != nil { 31 | events.Log("Skipped logging change to %{tableName}, can't extract keys: %{error}v", 32 | change.TableName, 33 | err) 34 | continue 35 | } 36 | 37 | for _, key := range keys { 38 | seq := atomic.AddInt64(&c.Seq, 1) 39 | err = c.ChangelogWriter.WriteChange(changelog.ChangelogEntry{ 40 | Seq: seq, 41 | Family: fam.Name, 42 | Table: tbl.Name, 43 | Key: key, 44 | }) 45 | if err != nil { 46 | events.Log("Skipped logging change to %{family}s.%{table}s:%{key}v: %{err}v", 47 | fam, tbl, key, err) 48 | continue 49 | } 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /pkg/ldbwriter/ldb_callback_writer.go: -------------------------------------------------------------------------------- 1 | package ldbwriter 2 | 3 | import ( 4 | "context" 5 | 
"database/sql" 6 | 7 | "github.com/segmentio/ctlstore/pkg/schema" 8 | "github.com/segmentio/ctlstore/pkg/sqlite" 9 | "github.com/segmentio/events/v2" 10 | ) 11 | 12 | // CallbackWriter is an LDBWriter that delegates to another 13 | // writer and then, upon a successful write, executes N callbacks. 14 | type CallbackWriter struct { 15 | DB *sql.DB 16 | Delegate LDBWriter 17 | Callbacks []LDBWriteCallback 18 | ChangeBuffer *sqlite.SQLChangeBuffer 19 | } 20 | 21 | func (w *CallbackWriter) ApplyDMLStatement(ctx context.Context, statement schema.DMLStatement) error { 22 | err := w.Delegate.ApplyDMLStatement(ctx, statement) 23 | if err != nil { 24 | return err 25 | } 26 | changes := w.ChangeBuffer.Pop() 27 | for _, callback := range w.Callbacks { 28 | events.Debug("Writing DML callback for %{cb}T", callback) 29 | callback.LDBWritten(ctx, LDBWriteMetadata{ 30 | DB: w.DB, 31 | Statement: statement, 32 | Changes: changes, 33 | }) 34 | } 35 | return nil 36 | } 37 | -------------------------------------------------------------------------------- /pkg/ldbwriter/ldb_writer_with_changelog.go: -------------------------------------------------------------------------------- 1 | package ldbwriter 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "sync/atomic" 7 | 8 | "github.com/segmentio/ctlstore/pkg/changelog" 9 | "github.com/segmentio/ctlstore/pkg/schema" 10 | "github.com/segmentio/ctlstore/pkg/sqlite" 11 | "github.com/segmentio/events/v2" 12 | ) 13 | 14 | type LDBWriterWithChangelog struct { 15 | LdbWriter LDBWriter 16 | ChangelogWriter *changelog.ChangelogWriter 17 | DB *sql.DB 18 | ChangeBuffer *sqlite.SQLChangeBuffer 19 | Seq int64 20 | } 21 | 22 | // 23 | // NOTE: How does the changelog work? 24 | // 25 | // This is sort of the crux of how the changelog comes together. The Reflector 26 | // sets a pre-update hook which populates a channel with any changes that happen 27 | // in the LDB. These changes end up on a buffered channel. 
After each statement 28 | // is executed, the pre-update hook will get called, filling in the channel. Once 29 | // that ApplyDMLStatement returns, the DML statement is committed and the channel 30 | // contains the contents of the update. Then this function takes over, extracts 31 | // the keys from the update, and writes them to the changelogWriter. 32 | // 33 | // This is pretty complex, but after enumerating about 8 different options, it 34 | // ended up actually being the most simple. Other options involved not-so-great 35 | // options like parsing SQL or maintaining triggers on every table. 36 | // 37 | func (w *LDBWriterWithChangelog) ApplyDMLStatement(ctx context.Context, statement schema.DMLStatement) error { 38 | err := w.LdbWriter.ApplyDMLStatement(ctx, statement) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | for _, change := range w.ChangeBuffer.Pop() { 44 | fam, tbl, err := schema.DecodeLDBTableName(change.TableName) 45 | if err != nil { 46 | // This is expected because it'll capture tables like ctlstore_dml_ledger, 47 | // which aren't tables this cares about. 
48 | events.Debug("Skipped logging change to %{tableName}s, can't decode table: %{error}v", 49 | change.TableName, 50 | err) 51 | continue 52 | } 53 | 54 | keys, err := change.ExtractKeys(w.DB) 55 | if err != nil { 56 | events.Log("Skipped logging change to %{tableName}, can't extract keys: %{error}v", 57 | change.TableName, 58 | err) 59 | continue 60 | } 61 | 62 | for _, key := range keys { 63 | seq := atomic.AddInt64(&w.Seq, 1) 64 | err = w.ChangelogWriter.WriteChange(changelog.ChangelogEntry{ 65 | Seq: seq, 66 | Family: fam.Name, 67 | Table: tbl.Name, 68 | Key: key, 69 | }) 70 | if err != nil { 71 | events.Log("Skipped logging change to %{family}s.%{table}s:%{key}v: %{err}v", 72 | fam, tbl, key, err) 73 | continue 74 | } 75 | } 76 | } 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /pkg/ledger/ecs_client.go: -------------------------------------------------------------------------------- 1 | package ledger 2 | 3 | import "github.com/aws/aws-sdk-go/service/ecs/ecsiface" 4 | 5 | //counterfeiter:generate -o fakes/ecs_client.go . 
ECSClient 6 | type ECSClient interface { 7 | ecsiface.ECSAPI 8 | } 9 | -------------------------------------------------------------------------------- /pkg/ledger/ecs_metadata.go: -------------------------------------------------------------------------------- 1 | package ledger 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/pkg/errors" 7 | ) 8 | 9 | type EcsMetadata struct { 10 | ContainerInstanceArn string 11 | Cluster string 12 | } 13 | 14 | // accountID parses the container instance arn and returns the account id portion 15 | // 16 | // ex: "arn:aws:ecs:us-west-2:[accountId]:container-instance/[instance-id]" 17 | func (m EcsMetadata) accountID() (string, error) { 18 | parts := strings.Split(m.ContainerInstanceArn, ":") 19 | if len(parts) != 6 { 20 | return "", errors.Errorf("invalid container instance arn: '%s'", m.ContainerInstanceArn) 21 | } 22 | return parts[4], nil 23 | } 24 | -------------------------------------------------------------------------------- /pkg/ledger/fake_ticker.go: -------------------------------------------------------------------------------- 1 | package ledger 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | // FakeTicker allows us to manually control when a send happens 9 | // on the channel. The Ticker property allows us to adhere to 10 | // the *time.Ticker interface. 
11 | type FakeTicker struct { 12 | Ticker *time.Ticker 13 | ch chan time.Time 14 | } 15 | 16 | func NewFakeTicker() *FakeTicker { 17 | ch := make(chan time.Time) 18 | return &FakeTicker{ 19 | ch: ch, 20 | Ticker: &time.Ticker{ 21 | C: ch, 22 | }, 23 | } 24 | } 25 | 26 | func (f *FakeTicker) Tick(ctx context.Context) { 27 | select { 28 | case f.ch <- time.Now(): 29 | case <-ctx.Done(): 30 | } 31 | } 32 | 33 | func (f *FakeTicker) Stop() { 34 | close(f.ch) 35 | } 36 | -------------------------------------------------------------------------------- /pkg/ledger/fake_ticker_test.go: -------------------------------------------------------------------------------- 1 | package ledger_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/segmentio/ctlstore/pkg/ledger" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestFakeTicker(t *testing.T) { 12 | ft := ledger.NewFakeTicker() 13 | go func() { 14 | defer ft.Stop() 15 | ft.Tick(context.Background()) 16 | ft.Tick(context.Background()) 17 | }() 18 | count := 0 19 | for range ft.Ticker.C { 20 | count++ 21 | } 22 | require.Equal(t, 2, count) 23 | } 24 | -------------------------------------------------------------------------------- /pkg/ledger/generate.go: -------------------------------------------------------------------------------- 1 | package ledger 2 | 3 | //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate 4 | -------------------------------------------------------------------------------- /pkg/ledger/opts.go: -------------------------------------------------------------------------------- 1 | package ledger 2 | 3 | import "time" 4 | 5 | func WithCheckCallback(fn func()) MonitorOpt { 6 | return func(m *Monitor) { 7 | m.checkCallback = fn 8 | } 9 | } 10 | 11 | func WithECSClient(ecsClient ECSClient) MonitorOpt { 12 | return func(m *Monitor) { 13 | m.ecsClient = ecsClient 14 | } 15 | } 16 | 17 | func WithECSMetadataFunc(fn ecsMetadataFunc) MonitorOpt { 18 | return func(m 
*Monitor) { 19 | m.ecsMetadataFunc = fn 20 | } 21 | } 22 | 23 | func WithTicker(ticker *time.Ticker) MonitorOpt { 24 | return func(m *Monitor) { 25 | m.tickerFunc = func() *time.Ticker { 26 | return ticker 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /pkg/limits/limits.go: -------------------------------------------------------------------------------- 1 | package limits 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "math" 7 | "time" 8 | 9 | "github.com/pkg/errors" 10 | "github.com/segmentio/ctlstore/pkg/units" 11 | ) 12 | 13 | const ( 14 | LimitRequestBodySize = 1 * units.MEGABYTE 15 | LimitMaxDMLSize = 768 * units.KILOBYTE 16 | LimitFieldValueSize = 512 * units.KILOBYTE 17 | 18 | LimitMaxMutateRequestCount = 100 19 | LimitWriterCookieSize = 1024 20 | 21 | LimitWriterSecretMaxLength = 100 22 | LimitWriterSecretMinLength = 3 23 | ) 24 | 25 | // TableSizeLimits is a representation of all of the table size limits 26 | type TableSizeLimits struct { 27 | Global SizeLimits `json:"global"` 28 | Tables []TableSizeLimit `json:"tables"` 29 | } 30 | 31 | // TableSizeLimit represents the limit for a particular table 32 | type TableSizeLimit struct { 33 | SizeLimits 34 | Family string `json:"family"` 35 | Table string `json:"table"` 36 | } 37 | 38 | // SizeLimits composes a max and a warn size 39 | type SizeLimits struct { 40 | MaxSize int64 `json:"max-size"` 41 | WarnSize int64 `json:"warn-size"` 42 | } 43 | 44 | // WriterRateLimits represents all of the writer limits 45 | type WriterRateLimits struct { 46 | Global RateLimit `json:"global"` 47 | Writers []WriterRateLimit `json:"writers"` 48 | } 49 | 50 | // WriterRateLimit represents the limit for a particular writer 51 | type WriterRateLimit struct { 52 | Writer string `json:"writer"` 53 | RateLimit RateLimit `json:"rate-limit"` 54 | } 55 | 56 | // RateLimit composes an amount allowed per duration 57 | type RateLimit struct { 58 | Amount int64 `json:"amount"` 
59 | Period time.Duration `json:"period"` 60 | } 61 | 62 | // UnmarshalJSON allows us to deser time.Durations using string values 63 | func (l *RateLimit) UnmarshalJSON(b []byte) error { 64 | var val map[string]interface{} 65 | if err := json.Unmarshal(b, &val); err != nil { 66 | return err 67 | } 68 | if amount, ok := val["amount"]; ok { 69 | switch amount := amount.(type) { 70 | case float64: 71 | l.Amount = int64(amount) 72 | default: 73 | return errors.Errorf("invalid amount: '%v'", amount) 74 | } 75 | } 76 | if period, ok := val["period"]; ok { 77 | switch period := period.(type) { 78 | case float64: 79 | l.Period = time.Duration(int64(period)) 80 | case string: 81 | parsed, err := time.ParseDuration(period) 82 | if err != nil { 83 | return errors.Errorf("invalid period: '%v'", period) 84 | } 85 | l.Period = parsed 86 | default: 87 | return errors.Errorf("invalid period: '%v'", period) 88 | } 89 | } 90 | return nil 91 | } 92 | 93 | func (l RateLimit) String() string { 94 | return fmt.Sprintf("%d/%v", l.Amount, l.Period) 95 | } 96 | 97 | // adjustAmount adjusts the composed amount for the specified period, rounded to the nearest second 98 | func (l RateLimit) AdjustAmount(period time.Duration) (int64, error) { 99 | if period.Seconds() <= 0 { 100 | return 0, errors.New("supplied period must be positive") 101 | } 102 | scaling := l.Period.Seconds() / period.Seconds() 103 | amount := float64(l.Amount) / scaling 104 | return int64(math.RoundToEven(amount)), nil 105 | } 106 | -------------------------------------------------------------------------------- /pkg/logwriter/sized_log_writer.go: -------------------------------------------------------------------------------- 1 | package logwriter 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | 7 | "github.com/pkg/errors" 8 | ) 9 | 10 | const sizedLogWriterDefaultMode os.FileMode = 0644 11 | 12 | // Implements a line-by-line log file writer that appends to a file 13 | // specified by Path until it reaches RotateSize bytes, at 
which point 14 | // it will delete the file and start over with a fresh one. 15 | // 16 | // Make sure to call Close() after this is no longer needed. 17 | type SizedLogWriter struct { 18 | RotateSize int 19 | Path string 20 | FileMode os.FileMode 21 | 22 | _f *os.File // don't use this directly, use file() 23 | } 24 | 25 | func (w *SizedLogWriter) Mode() os.FileMode { 26 | if w.FileMode == 0 { 27 | return sizedLogWriterDefaultMode 28 | } 29 | 30 | return w.FileMode 31 | } 32 | 33 | func (w *SizedLogWriter) File() (*os.File, error) { 34 | if w._f != nil { 35 | return w._f, nil 36 | } 37 | 38 | f, err := os.OpenFile(w.Path, os.O_CREATE|os.O_RDWR, w.Mode()) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | w._f = f 44 | return w._f, nil 45 | } 46 | 47 | func (w *SizedLogWriter) Rotate() error { 48 | var err error 49 | 50 | if w._f != nil { 51 | err = w._f.Close() 52 | if err != nil { 53 | return err 54 | } 55 | w._f = nil 56 | } 57 | 58 | err = os.Remove(w.Path) 59 | if err != nil { 60 | return err 61 | } 62 | 63 | return nil 64 | } 65 | 66 | // Close cleans up the associated resources 67 | func (w *SizedLogWriter) Close() error { 68 | if w._f != nil { 69 | err := w._f.Close() 70 | w._f = nil 71 | return err 72 | } 73 | return nil 74 | } 75 | 76 | // WriteLine appends a line to the end of the log file. If the log line would 77 | // exceed the set RotateSize, then the log file will be rotated, and the line 78 | // will be appended to the new log file. 
79 | func (w *SizedLogWriter) WriteLine(line string) error { 80 | f, err := w.File() 81 | if err != nil { 82 | return err 83 | } 84 | 85 | if strings.ContainsRune(line, '\n') { 86 | return errors.New("Lines can't contain a carriage-return") 87 | } 88 | 89 | if len(line) > w.RotateSize { 90 | return errors.New("Line length is > RotateSize") 91 | } 92 | 93 | offset, err := f.Seek(0, os.SEEK_END) 94 | if err != nil { 95 | return err 96 | } 97 | 98 | newEndOffset := offset + int64(len(line)) 99 | if newEndOffset > int64(w.RotateSize) { 100 | err = w.Rotate() 101 | if err != nil { 102 | return err 103 | } 104 | 105 | f, err = w.File() 106 | if err != nil { 107 | return err 108 | } 109 | } 110 | 111 | bytes := []byte(line) 112 | bytes = append(bytes, byte('\n')) 113 | _, err = f.Write(bytes) 114 | if err != nil { 115 | return err 116 | } 117 | 118 | return nil 119 | } 120 | -------------------------------------------------------------------------------- /pkg/logwriter/sized_log_writer_test.go: -------------------------------------------------------------------------------- 1 | package logwriter 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "testing" 7 | 8 | "github.com/google/go-cmp/cmp" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func newSLWTestPath(t *testing.T) (path string, teardown func()) { 13 | f, err := ioutil.TempFile("", "sized-log-writer-test") 14 | require.NoError(t, err) 15 | return f.Name(), func() { 16 | os.Remove(f.Name()) 17 | } 18 | } 19 | 20 | func TestSizedLogWriterCreatesFile(t *testing.T) { 21 | path, teardown := newSLWTestPath(t) 22 | defer teardown() 23 | w := SizedLogWriter{ 24 | RotateSize: 100000, 25 | Path: path, 26 | } 27 | defer w.Close() 28 | 29 | w.WriteLine("hello") 30 | 31 | bytes, err := ioutil.ReadFile(path) 32 | if err != nil { 33 | t.Fatalf("Unexpected error: %+v", err) 34 | } 35 | 36 | if diff := cmp.Diff([]byte("hello\n"), bytes); diff != "" { 37 | t.Errorf("Bytes differ\n%v", diff) 38 | } 39 | } 40 | 41 | func 
// TestSizedLogWriterRotatesFile verifies that a WriteLine which would push
// the file past RotateSize deletes the file and starts a fresh one.
func TestSizedLogWriterRotatesFile(t *testing.T) {
	path, teardown := newSLWTestPath(t)
	defer teardown()
	// Seed the file with one 11-byte line (10 chars + newline) so the
	// writer starts non-empty.
	err := ioutil.WriteFile(path, []byte("1234567890\n"), 0644)
	if err != nil {
		t.Fatalf("Unexpected error: %+v", err)
	}

	w := SizedLogWriter{
		RotateSize: 21, // chosen so it will rotate right at the third
		Path:       path,
	}
	defer w.Close()

	// Each line occupies 11 bytes on disk; 11 (existing) + 11 (new) = 22 > 21,
	// so both writes below trigger a rotation before appending.
	w.WriteLine("1234567890")
	w.WriteLine("1234567890")

	bytes, err := ioutil.ReadFile(path)
	if err != nil {
		t.Fatalf("Unexpected error: %+v", err)
	}

	// Only the last written line should survive the final rotation.
	if diff := cmp.Diff([]byte("1234567890\n"), bytes); diff != "" {
		t.Errorf("Bytes differ\n%v", diff)
	}
}
table_name from information_schema.tables order by table_name") 19 | if err != nil { 20 | return nil, errors.Wrap(err, "query table names") 21 | } 22 | for rows.Next() { 23 | var fullName string 24 | err = rows.Scan(&fullName) 25 | if err != nil { 26 | return nil, errors.Wrap(err, "scan table name") 27 | } 28 | if ft, ok := schema.ParseFamilyTable(fullName); ok { 29 | res = append(res, ft) 30 | } 31 | 32 | } 33 | return res, err 34 | } 35 | 36 | func (m *MySQLDBInfo) GetColumnInfo(ctx context.Context, tableNames []string) ([]schema.DBColumnInfo, error) { 37 | if len(tableNames) == 0 { 38 | return nil, nil 39 | } 40 | 41 | qs := sqlgen.SqlSprintf( 42 | "SELECT table_name, ordinal_position, column_name, data_type, column_key "+ 43 | "FROM information_schema.columns "+ 44 | "WHERE table_name IN ($1) "+ 45 | "AND table_schema = DATABASE() "+ 46 | "ORDER BY table_name, ordinal_position ASC", 47 | sqlgen.SQLPlaceholderSet(len(tableNames))) 48 | 49 | // []interface{} below won't accept []string 50 | ptrTableNames := []interface{}{} 51 | for _, tableName := range tableNames { 52 | ptrTableNames = append(ptrTableNames, tableName) 53 | } 54 | 55 | rows, err := m.Db.QueryContext(ctx, qs, ptrTableNames...) 
56 | if err != nil { 57 | return nil, err 58 | } 59 | defer rows.Close() 60 | 61 | columnInfos := []schema.DBColumnInfo{} 62 | 63 | for rows.Next() { 64 | var tableName string 65 | var index int 66 | var colName string 67 | var dataType string 68 | var colKey string 69 | 70 | err = rows.Scan( 71 | &tableName, 72 | &index, 73 | &colName, 74 | &dataType, 75 | &colKey, 76 | ) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | columnInfos = append(columnInfos, schema.DBColumnInfo{ 82 | TableName: tableName, 83 | Index: index, 84 | ColumnName: colName, 85 | DataType: dataType, 86 | IsPrimaryKey: (colKey == "PRI"), 87 | }) 88 | } 89 | err = rows.Err() 90 | if err != nil { 91 | return nil, err 92 | } 93 | 94 | return columnInfos, nil 95 | } 96 | -------------------------------------------------------------------------------- /pkg/reflector/dml_source.go: -------------------------------------------------------------------------------- 1 | package reflector 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/pkg/errors" 10 | "github.com/segmentio/ctlstore/pkg/schema" 11 | "github.com/segmentio/ctlstore/pkg/sqlgen" 12 | "github.com/segmentio/stats/v4" 13 | ) 14 | 15 | const ( 16 | defaultQueryBlockSize = 100 17 | dmlLedgerTimestampFormat = "2006-01-02 15:04:05" 18 | ) 19 | 20 | var errNoNewStatements = errors.New("No new statements") 21 | 22 | type dmlSource interface { 23 | Next(ctx context.Context) (schema.DMLStatement, error) 24 | // TODO: probably need a last sequence fetcher 25 | } 26 | 27 | // a dmlSource built on top of a database/sql instance 28 | type sqlDmlSource struct { 29 | db *sql.DB 30 | lastSequence schema.DMLSequence 31 | ledgerTableName string 32 | queryBlockSize int 33 | buffer []schema.DMLStatement 34 | scanLoopCallBack func() 35 | } 36 | 37 | // Next returns the next sequential statement in the source. If there are no 38 | // new statements, it returns errNoNewStatements. 
// Next returns the next sequential statement in the source. If there are no
// new statements, it returns errNoNewStatements. Any errors that occur while
// fetching data will be returned as well.
//
// Statements are fetched from the ledger table in blocks of queryBlockSize
// (defaultQueryBlockSize when unset) and buffered; each call pops one
// statement off the front of the buffer. lastSequence tracks the highest
// sequence absorbed into the buffer so the next fetch resumes after it.
func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLStatement, err error) {
	if len(source.buffer) == 0 {
		blocksize := source.queryBlockSize
		if blocksize == 0 {
			blocksize = defaultQueryBlockSize
		}

		// table layout is: seq, leader_ts, statement
		qs := sqlgen.SqlSprintf("SELECT seq, leader_ts, statement FROM $1 WHERE seq > ? ORDER BY seq LIMIT $2",
			source.ledgerTableName,
			fmt.Sprintf("%d", blocksize))

		// HMM: do we lean too hard on the LIMIT here? in the loop below
		// we'll end up spinning if the DB keeps feeding us data

		rows, err := source.db.QueryContext(ctx, qs, source.lastSequence)
		if err != nil {
			return statement, errors.Wrap(err, "select row")
		}

		// CR: reconsider naked returns here

		defer rows.Close()

		row := struct {
			seq       int64
			leaderTs  string // this is a string b/c the driver errors when trying to Scan into a *time.Time.
			statement string
		}{}

		for {
			// test hook: lets tests interleave actions (e.g. context
			// cancellation) with the scan loop
			if source.scanLoopCallBack != nil {
				source.scanLoopCallBack()
			}

			if !rows.Next() {
				break
			}

			err = rows.Scan(&row.seq, &row.leaderTs, &row.statement)
			if err != nil {
				return statement, errors.Wrap(err, "scan row")
			}

			// gaps in the sequence are tolerated but counted as a metric
			if schema.DMLSequence(row.seq) > source.lastSequence+1 {
				stats.Incr("sql_dml_source.skipped_sequence")
			}

			timestamp, err := time.Parse(dmlLedgerTimestampFormat, row.leaderTs)
			if err != nil {
				return statement, errors.Wrapf(err, "could not parse time '%s'", row.leaderTs)
			}

			dmlst := schema.DMLStatement{
				Sequence:  schema.DMLSequence(row.seq),
				Statement: row.statement,
				Timestamp: timestamp,
			}

			source.buffer = append(source.buffer, dmlst)

			// if this doesn't get updated every time, say just doing the last row
			// after the iteration, an early return can cause lastSequence to diverge
			// from the buffer contents
			source.lastSequence = dmlst.Sequence
		}

		err = rows.Err()
		if err != nil {
			return statement, errors.Wrap(err, "rows err")
		}
	}

	// Still have to guard this case because source.buffer gets
	// mutated above, and certainly could add zero statements.
	if len(source.buffer) > 0 {
		// FIFO queue
		statement = source.buffer[0]
		source.buffer = source.buffer[1:]
		return
	}

	err = errNoNewStatements
	return
}
// TestSqlDmlSource exercises the sqlDmlSource against an in-memory SQLite
// ledger: empty-source behavior, paging through two full query blocks,
// and cancellation mid-scan via the scanLoopCallBack test hook.
func TestSqlDmlSource(t *testing.T) {
	ctx := context.Background()
	db, err := sql.Open("sqlite3", ":memory:")
	require.NoError(t, err)

	srcutil := &sqlDmlSourceTestUtil{db: db, t: t}
	srcutil.InitializeDB()

	queryBlockSize := 5
	src := sqlDmlSource{
		db:              db,
		ledgerTableName: "ctlstore_dml_ledger",
		queryBlockSize:  queryBlockSize,
	}

	// an empty ledger should report no new statements
	_, err = src.Next(ctx)
	require.Equal(t, errNoNewStatements, err)

	var ststr string
	for i := 0; i < queryBlockSize*2; i++ {
		ststr = srcutil.AddStatement("INSERT INTO foo___bar VALUES('hi mom')")
	}

	// drain both blocks, verifying sequences are strictly increasing
	var lastSeq int64
	for i := 0; i < queryBlockSize*2; i++ {
		st, err := src.Next(ctx)
		require.NoError(t, err)
		require.Equal(t, ststr, st.Statement)
		require.True(t, st.Sequence.Int() > lastSeq)
		lastSeq = st.Sequence.Int()
	}

	_, err = src.Next(ctx)
	require.Equal(t, errNoNewStatements, err)

	srcutil.AddStatement("INSERT INTO foo___bar VALUES('hi bro')")

	// Context cancellation handled properly
	ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
	defer cancel()
	loopCounter := 0
	src.scanLoopCallBack = func() {
		if loopCounter == 1 {
			cancel()
		}
		loopCounter++
	}
	foundError := false
	for i := 0; i < 2; i++ {
		_, err = src.Next(ctx)
		cause := errors.Cause(err)
		switch {
		case cause == nil:
		case cause == context.Canceled:
			foundError = true
			break
		// the db driver will at some point return an error with
		// the value "interrupted" instead of returning
		// context.Canceled(). Sigh.
		case cause.Error() == "interrupted":
			foundError = true
			break
		}
	}
	if !foundError {
		t.Fatal("Expected a context error or an interrupted error")
	}
}
106 | case cause.Error() == "interrupted": 107 | foundError = true 108 | break 109 | } 110 | } 111 | if !foundError { 112 | t.Fatal("Expected a context error or an interrupted error") 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /pkg/reflector/download.go: -------------------------------------------------------------------------------- 1 | package reflector 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "io" 7 | "net/http" 8 | "strings" 9 | "time" 10 | 11 | "github.com/aws/aws-sdk-go/aws" 12 | "github.com/aws/aws-sdk-go/aws/awserr" 13 | "github.com/aws/aws-sdk-go/aws/session" 14 | "github.com/aws/aws-sdk-go/service/s3" 15 | "github.com/segmentio/errors-go" 16 | "github.com/segmentio/events/v2" 17 | "github.com/segmentio/stats/v4" 18 | 19 | "github.com/segmentio/ctlstore/pkg/errs" 20 | ) 21 | 22 | type downloadTo interface { 23 | DownloadTo(w io.Writer) (int64, error) 24 | } 25 | 26 | type S3Downloader struct { 27 | Region string // optional 28 | Bucket string 29 | Key string 30 | S3Client S3Client 31 | StartOverOnNotFound bool // whether we should rebuild LDB if snapshot not found 32 | } 33 | 34 | func (d *S3Downloader) DownloadTo(w io.Writer) (n int64, err error) { 35 | client, err := d.getS3Client() 36 | if err != nil { 37 | return -1, err 38 | } 39 | start := time.Now() 40 | defer func() { 41 | stats.Observe("snapshot_download_time", time.Now().Sub(start)) 42 | }() 43 | obj, err := client.GetObject(&s3.GetObjectInput{ 44 | Bucket: aws.String(d.Bucket), 45 | Key: aws.String(d.Key), 46 | }) 47 | if err != nil { 48 | switch err := err.(type) { 49 | case awserr.RequestFailure: 50 | if d.StartOverOnNotFound && err.StatusCode() == http.StatusNotFound { 51 | // don't bother retrying. we'll start with a fresh ldb. 
52 | return -1, errors.WithTypes(errors.Wrap(err, "get s3 data"), errs.ErrTypePermanent) 53 | } 54 | } 55 | // retry 56 | return -1, errors.WithTypes(errors.Wrap(err, "get s3 data"), errs.ErrTypeTemporary) 57 | } 58 | defer obj.Body.Close() 59 | compressedSize := obj.ContentLength 60 | var reader io.Reader = obj.Body 61 | if strings.HasSuffix(d.Key, ".gz") { 62 | reader, err = gzip.NewReader(reader) 63 | if err != nil { 64 | return n, errors.Wrap(err, "create gzip reader") 65 | } 66 | } 67 | n, err = io.Copy(w, reader) 68 | if err != nil { 69 | return n, errors.Wrap(err, "copy from s3 to writer") 70 | } 71 | if compressedSize != nil { 72 | events.Log("LDB inflated %d -> %d bytes", *compressedSize, n) 73 | } 74 | 75 | return 76 | } 77 | 78 | func (d *S3Downloader) getS3Client() (S3Client, error) { 79 | if d.S3Client != nil { 80 | return d.S3Client, nil 81 | } 82 | configs := []*aws.Config{} 83 | if d.Region != "" { 84 | configs = append(configs, &aws.Config{ 85 | Region: aws.String(d.Region), 86 | }) 87 | } 88 | sess := session.Must(session.NewSession(configs...)) 89 | client := s3.New(sess) 90 | return client, nil 91 | } 92 | 93 | type memoryDownloader struct { 94 | Content []byte 95 | } 96 | 97 | func (d *memoryDownloader) DownloadTo(w io.Writer) (int64, error) { 98 | return io.Copy(w, bytes.NewReader(d.Content)) 99 | } 100 | -------------------------------------------------------------------------------- /pkg/reflector/fakes/fake_reflector.go: -------------------------------------------------------------------------------- 1 | package fakes 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/segmentio/ctlstore/pkg/utils" 7 | ) 8 | 9 | type ( 10 | FakeReflector struct { 11 | Running utils.AtomicBool 12 | Closed utils.AtomicBool 13 | Events chan string 14 | } 15 | ) 16 | 17 | func NewFakeReflector() *FakeReflector { 18 | return &FakeReflector{ 19 | Events: make(chan string, 1024), 20 | } 21 | } 22 | 23 | func (r *FakeReflector) NextEvent(ctx context.Context) string 
{ 24 | select { 25 | case event := <-r.Events: 26 | return event 27 | case <-ctx.Done(): 28 | panic(ctx.Err()) 29 | } 30 | } 31 | 32 | func (r *FakeReflector) Start(ctx context.Context) error { 33 | r.Running.SetTrue() 34 | r.SendEvent("started") 35 | <-ctx.Done() 36 | r.Running.SetFalse() 37 | r.SendEvent("stopped") 38 | return ctx.Err() 39 | } 40 | 41 | func (r *FakeReflector) Stop() { 42 | 43 | } 44 | 45 | func (r *FakeReflector) Close() error { 46 | r.Closed.SetTrue() 47 | r.SendEvent("closed") 48 | return nil 49 | } 50 | 51 | func (r *FakeReflector) SendEvent(name string) { 52 | select { 53 | case r.Events <- name: 54 | default: 55 | panic("event chan full") 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /pkg/reflector/generate.go: -------------------------------------------------------------------------------- 1 | package reflector 2 | 3 | //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate 4 | -------------------------------------------------------------------------------- /pkg/reflector/jitter.go: -------------------------------------------------------------------------------- 1 | package reflector 2 | 3 | import ( 4 | "math/rand" 5 | "time" 6 | ) 7 | 8 | type jitter struct { 9 | Rand *rand.Rand 10 | } 11 | 12 | func newJitter() *jitter { 13 | src := rand.NewSource(time.Now().UTC().UnixNano()) 14 | rnd := rand.New(src) 15 | return &jitter{rnd} 16 | } 17 | 18 | func (j *jitter) Jitter(dur time.Duration, coefficient float64) time.Duration { 19 | val := float64(dur) + (float64(dur) * (coefficient * (j.Rand.Float64() - 0.5) * 2.0)) 20 | if val < 0.0 { 21 | return 0.0 22 | } 23 | return time.Duration(val) 24 | } 25 | -------------------------------------------------------------------------------- /pkg/reflector/pipeline_test.go: -------------------------------------------------------------------------------- 1 | package reflector 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 
// Exercises the basic components of the DML source and LDB writer/reader:
// statements inserted into an in-memory ctldb ledger are pulled via
// sqlDmlSource, applied to an in-memory LDB, and read back through the
// public LDB reader.
func TestPipelineIntegration(t *testing.T) {
	var err error
	ctx := context.Background()
	ldb, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Couldn't open LDB, error: %+v", err)
	}
	ctldb, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Couldn't open ctldb, error: %+v", err)
	}

	err = ldb2.EnsureLdbInitialized(ctx, ldb)
	if err != nil {
		t.Fatalf("Couldn't initialize LDB, error: %+v", err)
	}

	srcutil := &sqlDmlSourceTestUtil{db: ctldb, t: t}
	srcutil.InitializeDB()
	dmlsrc := &sqlDmlSource{db: ctldb, ledgerTableName: "ctlstore_dml_ledger"}

	ldbw := ldbwriter.SqlLdbWriter{Db: ldb}
	ldbr := ctlstore.NewLDBReaderFromDB(ldb)

	// drain the DML source into the LDB until it reports no new statements
	applyAllStatements := func() {
		for {
			st, err := dmlsrc.Next(ctx)
			if err == errNoNewStatements {
				return
			} else if err != nil {
				t.Fatalf("error reading statements from DML source, error: %+v", err)
			}
			// NOTE(review): the ApplyDMLStatement error is discarded here;
			// a failed apply would surface later as a read miss.
			ldbw.ApplyDMLStatement(ctx, st)
		}
	}

	srcutil.AddStatement("CREATE TABLE foo___bar (key VARCHAR PRIMARY KEY, val VARCHAR)")
	applyAllStatements()

	row := struct {
		Key string `ctlstore:"key"`
		Val string `ctlstore:"val"`
	}{}

	{
		found, err := ldbr.GetRowByKey(ctx, &row, "foo", "bar", "zzz")
		if err != nil {
			t.Errorf("Unexpected error reading from LDB: %+v", err)
		}
		if found {
			t.Error("Expected to not find any rows before we INSERT something")
		}
	}

	srcutil.AddStatement("INSERT INTO foo___bar VALUES('zzz', 'yyy')")
	applyAllStatements()

	{
		_, err := ldbr.GetRowByKey(ctx, &row, "foo", "bar", "zzz")
		if err != nil {
			t.Errorf("Unexpected error reading from LDB: %+v", err)
		}
		if row.Key != "zzz" {
			t.Errorf("Unexpected row key %+v", row.Key)
		}
		if row.Val != "yyy" {
			t.Errorf("Unexpected row val %+v", row.Val)
		}
	}
}
// TestReflectorCtlAppContextCloses verifies that starting the ctl with an
// already-canceled context never starts the underlying reflector.
func TestReflectorCtlAppContextCloses(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	reflector := fakes.NewFakeReflector()
	ctl := NewReflectorCtl(reflector)

	// kill the context once it's started
	cancel()

	// verify that starting the reflector with a canceled context
	// does not actually start the reflector
	// NOTE(review): the fixed sleep gives the ctl goroutine time to act;
	// a zero event count afterwards means no "started" event was emitted.
	ctl.Start(ctx)
	time.Sleep(100 * time.Millisecond)
	require.EqualValues(t, 0, len(reflector.Events))
}
36 | for i := 0; i < 5; i++ { 37 | ctl.Stop(ctx) 38 | } 39 | 40 | ctl.Start(ctx) 41 | 42 | // verify that the underlying reflector was started 43 | require.Equal(t, "started", reflector.NextEvent(ctx)) 44 | 45 | // once started, starting again should be a no-op 46 | for i := 0; i < 5; i++ { 47 | ctl.Start(ctx) 48 | } 49 | require.EqualValues(t, 0, len(reflector.Events)) 50 | 51 | ctl.Stop(ctx) 52 | require.Equal(t, "stopped", reflector.NextEvent(ctx)) 53 | 54 | // once stopped, stopping again should be a no-op 55 | for i := 0; i < 5; i++ { 56 | ctl.Stop(ctx) 57 | } 58 | require.EqualValues(t, 0, len(reflector.Events)) 59 | 60 | // restart and stop it again 61 | ctl.Start(ctx) 62 | require.Equal(t, "started", reflector.NextEvent(ctx)) 63 | ctl.Stop(ctx) 64 | require.Equal(t, "stopped", reflector.NextEvent(ctx)) 65 | 66 | // start it again, and then close it. 67 | ctl.Start(ctx) 68 | require.Equal(t, "started", reflector.NextEvent(ctx)) 69 | require.NoError(t, ctl.Close()) 70 | // closing the reflector involves first stopping it, 71 | // and then closing it. 72 | require.Equal(t, "stopped", reflector.NextEvent(ctx)) 73 | require.Equal(t, "closed", reflector.NextEvent(ctx)) 74 | 75 | // once closed, we should no longer be able to start 76 | // the reflector again. 77 | require.Panics(t, func() { 78 | ctl.Start(ctx) 79 | }) 80 | } 81 | 82 | // ensure that when a parent context is canceled it also cancels 83 | // the children contexts. cancel funcs only cancel their own 84 | // context root. 
func TestReflectorCtlContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	// child context; its own cancel func is deliberately discarded because
	// the test exercises cancellation propagated from the parent.
	ctx2, _ := context.WithCancel(ctx)
	cancel()
	select {
	case <-ctx2.Done():
	default:
		t.Fatal("context should have been canceled")
	}
}
-------------------------------------------------------------------------------- /pkg/reflector/s3_client.go: --------------------------------------------------------------------------------
package reflector

import (
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
)

// S3Client is the S3 API surface the reflector depends on; declared as an
// interface so a counterfeiter-generated fake can stand in during tests.
//counterfeiter:generate -o ../fakes/s3_client.go . S3Client
type S3Client interface {
	s3iface.S3API
}
-------------------------------------------------------------------------------- /pkg/reflector/shovel.go: --------------------------------------------------------------------------------
package reflector

import (
	"context"
	"io"
	"time"

	"github.com/segmentio/ctlstore/pkg/errs"
	"github.com/segmentio/ctlstore/pkg/ldbwriter"
	"github.com/segmentio/ctlstore/pkg/schema"
	"github.com/segmentio/errors-go"
	"github.com/segmentio/events/v2"
	"github.com/segmentio/stats/v4"
)

// shovel pumps DML statements from a dmlSource into an LDBWriter, polling
// with jittered sleeps and optionally aborting when a ledger sequence gap
// is detected.
type shovel struct {
	source            dmlSource          // where sequenced DML statements come from
	closers           []io.Closer        // resources to close when the shovel shuts down
	writer            ldbwriter.LDBWriter // destination for applied statements
	pollInterval      time.Duration      // base sleep between empty polls
	pollTimeout       time.Duration      // per-poll deadline applied to source.Next
	jitterCoefficient float64            // randomization factor applied to pollInterval
	abortOnSeqSkip    bool               // when true, a skipped sequence aborts the shovel
	maxSeqOnStartup   int64              // sequences at or below this are exempt from skip detection
	stop              chan struct{}      // closed/raised externally to request a clean stop
	log               *events.Logger     // optional; logger() falls back to events.DefaultLogger
}

// Start runs the shovel loop until the stop channel fires, the context is
// canceled, or a non-recoverable error occurs. It polls the source under a
// per-iteration timeout and applies each statement to the writer.
func (s *shovel) Start(ctx context.Context) error {
	jitr := newJitter()

	var cancel context.CancelFunc
	safeCancel := func() {
		if cancel != nil {
			cancel()
		}
	}

	// lastSeq tracks the last applied ledger sequence, used to detect gaps.
	var lastSeq schema.DMLSequence

	// Only actually close out the final cancel
	defer safeCancel()

	for {
		// early exit here if the shovel should be stopped
		select {
		case <-s.stop:
			s.logger().Log("Shovel stopping normally")
			return nil
		default:
		}

		// Need to clean up the cancel for each call of the loop, to avoid
		// leaking context.
		safeCancel()
		var sctx context.Context
		sctx, cancel = context.WithTimeout(ctx, s.pollTimeout)

		stats.Incr("shovel.loop_enter")
		s.logger().Debug("shovel polling...")
		st, err := s.source.Next(sctx)

		if err != nil {
			// Only DeadlineExceeded and errNoNewStatements are recoverable;
			// anything else terminates the shovel.
			causeErr := errors.Cause(err)
			if causeErr != context.DeadlineExceeded && causeErr != errNoNewStatements {
				return err
			}

			if causeErr == context.DeadlineExceeded {
				errs.Incr("shovel.deadline_exceeded")
			}

			//
			// The sctx deadline will trigger the DeadlineExceeded err, which
			// would happen in the case that the backing store for the source
			// is slow.
			//
			// Otherwise, errNoNewStatements is a positive assertion that the
			// no new statements have been found.
			//

			pollSleep := jitr.Jitter(s.pollInterval, s.jitterCoefficient)
			s.logger().Debug("Poll sleep %{sleepTime}s", pollSleep)

			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(pollSleep):
				// sctx timeouts will fall through here, so we should probably
				// TODO: add exponential backoff for retries
			}
			continue
		}

		s.logger().Debug("Shovel applying %{statement}v", st)

		// Sequence-gap detection: only active once a statement has been
		// applied (lastSeq != 0) and past the startup high-water mark.
		if lastSeq != 0 {
			if st.Sequence > lastSeq+1 && st.Sequence.Int() > s.maxSeqOnStartup {
				stats.Incr("shovel.skipped_sequence")
				s.logger().Log("shovel skip sequence from:%{fromSeq}d to:%{toSeq}d", lastSeq, st.Sequence)

				if s.abortOnSeqSkip {
					// Mitigation for a bug that we haven't found yet
					stats.Incr("shovel.skipped_sequence_abort")
					err = errors.New("shovel skipped sequence")
					err = errors.WithTypes(err, "SkippedSequence")
					return err
				}
			}
		}

		// there's actually a statement to work
		err = s.writer.ApplyDMLStatement(ctx, st)
		if err != nil {
			errs.Incr("shovel.apply_statement.error")
			return errors.Wrapf(err, "ledger seq: %d", st.Sequence)
		}

		lastSeq = st.Sequence

		stats.Incr("shovel.apply_statement.success")

		// check if the context is done each loop
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
			// non-blocking
		}
	}
}

// Close closes every registered closer, logging (not returning) individual
// failures; it always returns nil.
func (s *shovel) Close() error {
	for _, closer := range s.closers {
		err := closer.Close()
		if err != nil {
			s.logger().Log("shovel encountered error during close: %{error}s", err)
		}
	}
	return nil
}

// logger returns the configured logger, lazily defaulting to
// events.DefaultLogger when none was set.
func (s *shovel) logger() *events.Logger {
	if s.log == nil {
		s.log = events.DefaultLogger
	}
	return s.log
}
-------------------------------------------------------------------------------- /pkg/scanfunc/marshal.go: --------------------------------------------------------------------------------
package scanfunc

import (
	"errors"
	"reflect"
	"sync"

	"github.com/segmentio/ctlstore/pkg/unsafe"
)

type (
	// UnmarshalMetaCache memoizes per-type unmarshal metadata behind an RWMutex.
	UnmarshalMetaCache struct {
		cache map[reflect.Type]UnmarshalTypeMeta
		mu    sync.RWMutex
	}
	// UnmarshalTypeMeta holds the per-field metadata for one target type.
	UnmarshalTypeMeta struct {
		Fields map[string]UnmarshalTypeMetaField
	}
	// UnmarshalTypeMetaField pairs a struct field with its interface factory.
	UnmarshalTypeMetaField struct {
		Field   reflect.StructField
		Factory unsafe.InterfaceFactory
	}
	// UtmGetterFunc computes metadata for a type on cache miss.
	UtmGetterFunc func(reflect.Type) (UnmarshalTypeMeta, error)
)

var ErrUnmarshalUnsupportedType = errors.New("only map[string]interface{} and struct pointer types are supported for unmarshalling")

// UtcNoopScanner is a shared placeholder scanner value.
// NOTE(review): NoOpScanner declares Scan on a pointer receiver
// (noop_scanner.go), so this non-pointer value does not itself satisfy
// sql.Scanner — confirm call sites rely on it only as a placeholder or take
// its address.
var UtcNoopScanner interface{} = NoOpScanner{}
var UtcCache = UnmarshalMetaCache{cache: map[reflect.Type]UnmarshalTypeMeta{}}

// GetOrSet returns the cached metadata for typ, computing and storing it via
// getter on a miss. The lock is temporarily upgraded from read to write for
// the store; see the inline note about the benign race in that window.
func (umc *UnmarshalMetaCache) GetOrSet(typ reflect.Type, getter UtmGetterFunc) (UnmarshalTypeMeta, error) {
	umc.mu.RLock()
	// The deferred RUnlock pairs with the re-acquired read lock at the end of
	// the miss path (or with the RLock above on the hit path).
	defer umc.mu.RUnlock()

	if meta, ok := umc.cache[typ]; ok {
		return meta, nil
	}

	umc.mu.RUnlock()
	umc.mu.Lock()

	// There's a race here between the RUnlock and Lock, since another writer
	// could be waiting for write lock ahead of the line of this writer. That's
	// ok though because collecting this information isn't super expensive, and
	// the getter will always return the same thing for a type. The program
	// will quickly map out all the types used and this becomes basically
	// fixed.

	meta, err := getter(typ)
	if err == nil {
		umc.cache[typ] = meta
	}

	umc.mu.Unlock()
	umc.mu.RLock()
	return meta, err
}

// Invalidate removes any cached metadata for typ.
func (umc *UnmarshalMetaCache) Invalidate(typ reflect.Type) {
	umc.mu.Lock()
	defer umc.mu.Unlock()
	delete(umc.cache, typ)
}
-------------------------------------------------------------------------------- /pkg/scanfunc/noop_scanner.go: --------------------------------------------------------------------------------
package scanfunc

// Placeholder for columns which have no corresponding field in the
// target struct.
type NoOpScanner struct{}

// Scan discards the source value and always succeeds.
// NOTE(review): pointer receiver means a NoOpScanner *value* (e.g. the one
// stored in UtcNoopScanner in marshal.go) does not satisfy sql.Scanner —
// confirm callers use a pointer.
func (s *NoOpScanner) Scan(src interface{}) error {
	return nil
}
-------------------------------------------------------------------------------- /pkg/schema/db_column_info.go: --------------------------------------------------------------------------------
package schema

// DBColumnInfo describes one column of a table: its position, SQL data type,
// and whether it participates in the primary key.
type DBColumnInfo struct {
	TableName    string
	Index        int
	ColumnName   string
	DataType     string
	IsPrimaryKey bool
}
-------------------------------------------------------------------------------- /pkg/schema/db_column_meta.go: --------------------------------------------------------------------------------
package schema

import "database/sql"

// DBColumnMeta is a column's name paired with its database type name.
type DBColumnMeta struct {
	Name string
	Type string
}

// DBColumnMetaFromRows extracts name/type metadata for every column of a
// result set, in column order.
func DBColumnMetaFromRows(rows *sql.Rows) ([]DBColumnMeta, error) {
	typs, err := rows.ColumnTypes()
	if err != nil {
		return nil, err
	}
	res := make([]DBColumnMeta, 0, len(typs))
	for _, typ := range typs {
		res = append(res, DBColumnMeta{
			Name: typ.Name(),
			Type: typ.DatabaseTypeName(),
		})
	}
	return res, nil
}
-------------------------------------------------------------------------------- /pkg/schema/dml.go: --------------------------------------------------------------------------------
package schema

import (
	"sync/atomic"
	"time"
)

// These are the markers used to indicate the start and end of transactions
// in ctldb's DML log.
const DMLTxBeginKey = "--- BEGIN"
const DMLTxEndKey = "--- COMMIT"

// currentTestDmlSeq backs nextTestDmlSeq; advanced atomically.
var currentTestDmlSeq int64

// DMLSequence is a monotonically increasing ledger sequence number.
type DMLSequence int64

// DMLStatement is one ledger entry: a SQL statement with its sequence and
// timestamp.
type DMLStatement struct {
	Sequence  DMLSequence
	Timestamp time.Time
	Statement string
}

// Int returns the sequence as a plain int64.
func (seq DMLSequence) Int() int64 {
	return int64(seq)
}

// used for testing
func NewTestDMLStatement(statement string) DMLStatement {
	return DMLStatement{
		Statement: statement,
		Sequence:  nextTestDmlSeq(),
		Timestamp: time.Now(),
	}
}

// nextTestDmlSeq atomically hands out the next test sequence number.
func nextTestDmlSeq() DMLSequence {
	return DMLSequence(atomic.AddInt64(&currentTestDmlSeq, 1))
}
-------------------------------------------------------------------------------- /pkg/schema/family_name.go: --------------------------------------------------------------------------------
package schema

import (
	"errors"
	"fmt"
	"regexp"
	"strings"
)

// FamilyName is a validated, normalized (lowercased) table-family name.
type FamilyName struct {
	Name string
}

const (
	MinFamilyNameLength = 3
	MaxFamilyNameLength = 30
)

var (
	ErrFamilyNameInvalid  = errors.New("Family names must be only letters, numbers, and single underscore")
	ErrFamilyNameTooLong  = fmt.Errorf("Family names can only be up to %d characters", MaxFamilyNameLength)
	ErrFamilyNameTooShort = fmt.Errorf("Family names must be at least %d characters", MinFamilyNameLength)
)

// Allowed shape: empty, or a lowercase letter followed by lowercase letters,
// digits, and underscores. Double underscores are rejected separately below.
var familyNameChars = regexp.MustCompile("^$|^[a-z][a-z0-9_]*$")

// NewFamilyName validates and normalizes name into a FamilyName.
func NewFamilyName(name string) (FamilyName, error) {
	normalized, err := normalizeFamilyName(name)
	if err != nil {
		return FamilyName{}, err
	}
	return FamilyName{normalized}, nil
}

func (fn FamilyName) String() string {
	return fn.Name
}

// normalizeFamilyName lowercases and validates a raw family name, enforcing
// the character set, the no-double-underscore rule, and length bounds.
func normalizeFamilyName(familyName string) (string, error) {
	lowered := strings.ToLower(familyName)
	if strings.Contains(lowered, "__") {
		return "", ErrFamilyNameInvalid
	}
	if !familyNameChars.MatchString(lowered) {
		return "", ErrFamilyNameInvalid
	}
	if len(lowered) > MaxFamilyNameLength {
		return "", ErrFamilyNameTooLong
	}
	if len(lowered) < MinFamilyNameLength {
		return "", ErrFamilyNameTooShort
	}
	return lowered, nil
}
-------------------------------------------------------------------------------- /pkg/schema/family_table.go: --------------------------------------------------------------------------------
package schema

import (
	"strings"

	"github.com/segmentio/stats/v4"
)

// FamilyTable composes a family name and a table name
type FamilyTable struct {
	Family string `json:"family"`
	Table  string `json:"table"`
}

// String is a Stringer implementation that produces the fully qualified table name
func (ft FamilyTable) String() string {
	return strings.Join([]string{ft.Family, ft.Table}, ldbTableNameDelimiter)
}

// Tag produces a stats tag that can be used to represent this table
func (ft FamilyTable) Tag() stats.Tag {
	return stats.Tag{Name: "table", Value: ft.String()}
}

// parseFamilyTable breaks up a full table name into family/table parts.
func ParseFamilyTable(fullName string) (ft FamilyTable, ok bool) {
	parts := strings.Split(fullName, ldbTableNameDelimiter)
	// Exactly one delimiter is required; anything else is not a valid
	// family___table name.
	if len(parts) != 2 {
		return ft, false
	}
	ft.Family = parts[0]
	ft.Table = parts[1]
	return ft, true
}
-------------------------------------------------------------------------------- /pkg/schema/field_name.go: --------------------------------------------------------------------------------
package schema

import (
	"errors"
	"fmt"
	"regexp"
	"strings"
)

const (
	MinFieldNameLength = 1
	MaxFieldNameLength = 64
)

// Allowed shape: empty, or a lowercase letter followed by lowercase letters,
// digits, and underscores.
var fieldNameChars = regexp.MustCompile("^$|^[a-z][a-z0-9_]*$")

var (
	ErrFieldNameInvalid  = errors.New("Field names must be only letters, numbers, and underscore")
	ErrFieldNameTooLong  = fmt.Errorf("Field names can only be up to %d characters", MaxFieldNameLength)
	ErrFieldNameTooShort = fmt.Errorf("Field names must be at least %d characters", MinFieldNameLength)
)

// FieldName is a validated, normalized (lowercased) column/field name.
type FieldName struct {
	Name string
}

func (f FieldName) String() string {
	return f.Name
}

// NewFieldName validates and normalizes name into a FieldName.
func NewFieldName(name string) (FieldName, error) {
	normalized, err := normalizeFieldName(name)
	if err != nil {
		return FieldName{}, err
	}
	return FieldName{Name: normalized}, nil
}

// StringifyFieldNames converts a slice of FieldNames to their raw strings,
// preserving order.
func StringifyFieldNames(fns []FieldName) []string {
	out := make([]string, len(fns))
	for i, fn := range fns {
		out[i] = fn.Name
	}
	return out
}

// normalizeFieldName lowercases and validates a raw field name against the
// character set and length bounds.
func normalizeFieldName(fieldName string) (string, error) {
	lowered := strings.ToLower(fieldName)
	if !fieldNameChars.MatchString(lowered) {
		return "", ErrFieldNameInvalid
	}
	if len(lowered) > MaxFieldNameLength {
		return "", ErrFieldNameTooLong
	}
	if len(lowered) < MinFieldNameLength {
		return "", ErrFieldNameTooShort
	}
	return lowered, nil
}
-------------------------------------------------------------------------------- /pkg/schema/field_type.go: --------------------------------------------------------------------------------
package schema

import "strings"

// FieldType enumerates the value types a ctlstore field can have.
type FieldType int

// CanBeKey returns if the field type can be used in a PK
func (ft FieldType) CanBeKey() bool {
	return ft == FTString || ft == FTInteger || ft == FTByteString
}

// String returns the canonical lowercase name of the type, or "unknown" for
// values outside the enum.
func (ft FieldType) String() string {
	if s, ok := FieldTypeStringsByFieldType[ft]; ok {
		return s
	}

	return "unknown"
}

const (
	_ FieldType = iota // zero value deliberately unused so it reads as invalid
	FTString
	FTInteger
	FTDecimal
	FTText
	FTBinary
	FTByteString
)

// Maps FieldTypes to their stringly typed version
var FieldTypeStringsByFieldType = map[FieldType]string{
	FTString:     "string",
	FTInteger:    "integer",
	FTDecimal:    "decimal",
	FTText:       "text",
	FTBinary:     "binary",
	FTByteString: "bytestring",
}

// Used for converting SQL-ized field types to FieldTypes
var _sqlTypesToFieldTypes = map[string]FieldType{
	"varchar":      FTString,
	"varchar(191)": FTString,
	"char":         FTString,
	"character":    FTString,

	"text":       FTText,
	"mediumtext": FTText,
	"longtext":   FTText,

	"integer":   FTInteger,
	"smallint":  FTInteger,
	"mediumint": FTInteger,
	"bigint":    FTInteger,

	"real":   FTDecimal,
	"float":  FTDecimal,
	"double": FTDecimal,

	"blob":       FTBinary,
	"mediumblob": FTBinary,
	"longblob":   FTBinary,

	"varbinary": FTByteString,
	"blob(255)": FTByteString,
}

// Convert a known SQL type string to a FieldType
func SqlTypeToFieldType(sqlType string) (FieldType, bool) {
	// TODO: write a test that resolves all known generated types against this one
	loweredType := strings.ToLower(sqlType)
	ft, ok := _sqlTypesToFieldTypes[loweredType]
	return ft, ok
}

// Returns a map of stringly typed field types to strongly typed field types
func FieldTypeMap() map[string]FieldType {
	ftm := map[string]FieldType{}
	for ft, str := range FieldTypeStringsByFieldType {
		ftm[str] = ft
	}
	return ftm
}
-------------------------------------------------------------------------------- /pkg/schema/ldb_table_name.go: --------------------------------------------------------------------------------
package schema

import (
	"errors"
	"strings"
)

const (
	ldbTableNameDelimiter = "___"
)

// Converts a family/table name pair to a concatenated version that works
// with SQLite, which doesn't support SQL schema objects. Use ___ to avoid
// easy accidental hijacks.
func LDBTableName(famName FamilyName, tblName TableName) string {
	return strings.Join(
		[]string{famName.Name, tblName.Name},
		ldbTableNameDelimiter)
}

// The opposite of ldbTableName()
func DecodeLDBTableName(tableName string) (fn FamilyName, tn TableName, err error) {
	splitted := strings.Split(tableName, ldbTableNameDelimiter)
	if len(splitted) != 2 {
		err = errors.New("decodeLdbTableName couldn't split string properly")
		return
	}

	// Both halves are re-validated through the normal constructors.
	fn, err = NewFamilyName(splitted[0])
	if err != nil {
		return
	}

	tn, err = NewTableName(splitted[1])
	if err != nil {
		return
	}

	return
}
-------------------------------------------------------------------------------- /pkg/schema/named_field_type.go: --------------------------------------------------------------------------------
package schema

// Roses are red,
// Violets are blue,
// This type would fill me with less existential dread,
// If Go had a tuple type instead
type NamedFieldType struct {
	Name      FieldName
	FieldType FieldType
}
-------------------------------------------------------------------------------- /pkg/schema/params.go: --------------------------------------------------------------------------------
package schema

import (
	"fmt"

	"github.com/segmentio/ctlstore/pkg/errs"
)

// UnzipFieldsParam splits request-style [[name, type], ...] pairs into
// parallel name and FieldType slices. Malformed tuples or unknown type
// strings produce an errs.BadRequestError.
func UnzipFieldsParam(fields [][]string) (fieldNames []string, fieldTypes []FieldType, err error) {
	fieldNames = []string{}
	fieldTypes = []FieldType{}
	typeMap := FieldTypeMap()

	for idx, fieldTuple := range fields {
		// Each tuple must be exactly [name, type].
		if want, got := 2, len(fieldTuple); want != got {
			err = &errs.BadRequestError{Err: fmt.Sprintf("Field #%d is malformed: expected %d elements, got %d", idx, want, got)}
			return
		}

		rawType := fieldTuple[1]
		mappedType, ok := typeMap[rawType]
		if !ok {
			err = &errs.BadRequestError{Err: fmt.Sprintf("Field #%d: Type '%s' unknown", idx, rawType)}
			return
		}

		fieldNames = append(fieldNames, fieldTuple[0])
		fieldTypes = append(fieldTypes, mappedType)
	}
	return
}
-------------------------------------------------------------------------------- /pkg/schema/primary_key.go: --------------------------------------------------------------------------------
package schema

import errors "github.com/segmentio/errors-go"

// PrimaryKeyZero is the zero-value PrimaryKey, returned on construction errors.
var PrimaryKeyZero = PrimaryKey{}

// PrimaryKey is the ordered set of fields (and their types) forming a
// table's primary key.
type PrimaryKey struct {
	Fields []FieldName
	Types  []FieldType
}

// Is this a zero value?
func (pk *PrimaryKey) Zero() bool {
	return len(pk.Fields) == 0
}

// Returns the list of fields as strings
func (pk *PrimaryKey) Strings() []string {
	out := make([]string, len(pk.Fields))
	for i, fn := range pk.Fields {
		out[i] = fn.Name
	}
	return out
}

// builds a new primary key from a slice of field names and the corresponding field types.
func NewPKFromRawNamesAndFieldTypes(names []string, types []FieldType) (PrimaryKey, error) {
	fns := make([]FieldName, len(names))
	for i, name := range names {
		// Each raw name is validated/normalized; first failure aborts.
		fn, err := NewFieldName(name)
		if err != nil {
			return PrimaryKeyZero, err
		}
		fns[i] = fn
	}
	return PrimaryKey{Fields: fns, Types: types}, nil
}

// builds a new primary key from a slice of field name and the string representation of
// the field types. The field types in this case should have entries in the
// _sqlTypesToFieldTypes map. A failure to map from a field type string to a FieldType
// will result in an error.
func NewPKFromRawNamesAndTypes(names []string, types []string) (PrimaryKey, error) {
	fns := make([]FieldName, len(names))
	fts := make([]FieldType, len(names))
	for i, name := range names {
		fn, err := NewFieldName(name)
		if err != nil {
			return PrimaryKeyZero, err
		}
		ft, ok := SqlTypeToFieldType(types[i])
		if !ok {
			return PrimaryKeyZero, errors.Errorf("no field type found for '%s'", types[i])
		}
		fns[i] = fn
		fts[i] = ft
	}
	return PrimaryKey{Fields: fns, Types: fts}, nil
}
-------------------------------------------------------------------------------- /pkg/schema/table.go: --------------------------------------------------------------------------------
package schema

// Table is the wire/JSON representation of a table definition: its family,
// name, [[name, type], ...] field pairs, and key field names.
type Table struct {
	Family    string     `json:"family"`
	Name      string     `json:"name"`
	Fields    [][]string `json:"fields"`
	KeyFields []string   `json:"keyFields"`
}
-------------------------------------------------------------------------------- /pkg/schema/table_name.go: --------------------------------------------------------------------------------
package schema

import (
	"errors"
	"fmt"
	"regexp"
	"strings"
)

const (
	MinTableNameLength = 3
	MaxTableNameLength = 50
)

// use newTableName to construct a tableName
type TableName struct {
	Name string
}

var TableNameZero = TableName{}

// Allowed shape: empty, or a lowercase letter followed by lowercase letters,
// digits, and underscores. Double underscores are rejected separately below.
var tableNameChars = regexp.MustCompile("^$|^[a-z][a-z0-9_]*$")

var (
	ErrTableNameInvalid  = errors.New("Table names must be only letters, numbers, and single underscore")
	ErrTableNameTooLong  = fmt.Errorf("Table names can only be up to %d characters", MaxTableNameLength)
	ErrTableNameTooShort = fmt.Errorf("Table names must be at least %d characters", MinTableNameLength)
)

// NewTableName validates and normalizes name into a TableName.
func NewTableName(name string) (TableName, error) {
	normalized, err := normalizeTableName(name)
	if err != nil {
		return TableNameZero, err
	}
	return TableName{normalized}, nil
}

func (tn TableName) String() string {
	return tn.Name
}

// normalizeTableName lowercases and validates a raw table name, enforcing
// the character set, the no-double-underscore rule, and length bounds.
func normalizeTableName(tableName string) (string, error) {
	lowered := strings.ToLower(tableName)
	if strings.Contains(lowered, "__") {
		return "", ErrTableNameInvalid
	}
	if !tableNameChars.MatchString(lowered) {
		return "", ErrTableNameInvalid
	}
	if len(lowered) > MaxTableNameLength {
		return "", ErrTableNameTooLong
	}
	if len(lowered) < MinTableNameLength {
		return "", ErrTableNameTooShort
	}
	return lowered, nil
}
-------------------------------------------------------------------------------- /pkg/schema/validate_test.go: --------------------------------------------------------------------------------
package schema

import (
	"fmt"
	"strings"
	"testing"
)

// TestNormalizeFamilyName covers lowercasing plus each rejection branch of
// normalizeFamilyName.
func TestNormalizeFamilyName(t *testing.T) {
	suite := []struct {
		desc      string
		input     string
		expectStr string
		expectErr error
	}{
		{"Lowers", "LOWER", "lower", nil},
		{"Too short", "ab", "", ErrFamilyNameTooShort},
		{"Too long", strings.Repeat("a", 31), "", ErrFamilyNameTooLong},
		{"Invalid chars", "abc-123", "", ErrFamilyNameInvalid},
		{"Starts with number", "1abc", "", ErrFamilyNameInvalid},
		{"Contains multi-underscore", "a__b", "", ErrFamilyNameInvalid},
	}

	for i, testCase := range suite {
		testName := fmt.Sprintf("%d %s", i, testCase.desc)
		t.Run(testName, func(t *testing.T) {
			gotStr, gotErr := normalizeFamilyName(testCase.input)
			if want, got := testCase.expectErr, gotErr; want != got {
				t.Errorf("Expected error %v, got %v", want, got)
			}
			if want, got := testCase.expectStr, gotStr; want != got {
				t.Errorf("Expected %v, got %v", want, got)
			}
		})
	}
}

// TestNormalizeTableName mirrors the family-name suite for normalizeTableName.
func TestNormalizeTableName(t *testing.T) {
	suite := []struct {
		desc      string
		input     string
		expectStr string
		expectErr error
	}{
		{"Lowers", "LOWER", "lower", nil},
		{"Too short", "ab", "", ErrTableNameTooShort},
		{"Too long", strings.Repeat("a", 51), "", ErrTableNameTooLong},
		{"Invalid chars", "abc-123", "", ErrTableNameInvalid},
		{"Starts with number", "1abc", "", ErrTableNameInvalid},
		{"Contains multi-underscore", "a__b", "", ErrTableNameInvalid},
	}

	for i, testCase := range suite {
		testName := fmt.Sprintf("%d %s", i, testCase.desc)
		t.Run(testName, func(t *testing.T) {
			gotStr, gotErr := normalizeTableName(testCase.input)
			if want, got := testCase.expectErr, gotErr; want != got {
				t.Errorf("Expected error %v, got %v", want, got)
			}
			if want, got := testCase.expectStr, gotStr; want != got {
				t.Errorf("Expected %v, got %v", want, got)
			}
		})
	}
}

// TestNormalizeFieldName covers lowercasing plus each rejection branch of
// normalizeFieldName.
func TestNormalizeFieldName(t *testing.T) {
	suite := []struct {
		desc      string
		input     string
		expectStr string
		expectErr error
	}{
		{"Lowers", "LOWER", "lower", nil},
		{"Too short", "", "", ErrFieldNameTooShort},
		{"Too long", strings.Repeat("a", 100), "", ErrFieldNameTooLong},
		{"Invalid chars", "abc-123", "", ErrFieldNameInvalid},
		{"Starts with number", "1abc", "", ErrFieldNameInvalid},
	}

	for i, testCase := range suite {
		testName := fmt.Sprintf("%d %s", i, testCase.desc)
		t.Run(testName, func(t *testing.T) {
			gotStr, gotErr := normalizeFieldName(testCase.input)
			if want, got := testCase.expectErr, gotErr; want != got {
				t.Errorf("Expected error %v, got %v", want, got)
			}
			if want, got := testCase.expectStr, gotStr; want != got {
				t.Errorf("Expected %v, got %v", want, got)
			}
		})
	}
}

// TestValidateWriterName covers the length bounds enforced by
// validateWriterName (writer names are otherwise unrestricted).
func TestValidateWriterName(t *testing.T) {
	suite := []struct {
		desc      string
		input     string
		expectStr string
		expectErr error
	}{
		{"Too short", "x", "", ErrWriterNameTooShort},
		{"Too long", strings.Repeat("a", 51), "", ErrWriterNameTooLong},
	}

	for i, testCase := range suite {
		testName := fmt.Sprintf("%d %s", i, testCase.desc)
		t.Run(testName, func(t *testing.T) {
			gotStr, gotErr := validateWriterName(testCase.input)
			if want, got := testCase.expectErr, gotErr; want != got {
				t.Errorf("Expected error %v, got %v", want, got)
			}
			if want, got := testCase.expectStr, gotStr; want != got {
				t.Errorf("Expected %v, got %v", want, got)
			}
		})
	}
}
-------------------------------------------------------------------------------- /pkg/schema/writer_name.go: --------------------------------------------------------------------------------
package schema

import "fmt"

// use newWriterName to construct a writerName
type WriterName struct {
	Name string
}

const (
	MinWriterNameLength = 3
	MaxWriterNameLength = 50
)

var (
	ErrWriterNameTooLong  = fmt.Errorf("Writer names can only be up to %d characters", MaxWriterNameLength)
	ErrWriterNameTooShort = fmt.Errorf("Writer names must be at least %d characters", MinWriterNameLength)
)

// validateWriterName enforces only length bounds; unlike family/table/field
// names, no character-set or case normalization is applied.
func validateWriterName(writerName string) (string, error) {
	if len(writerName) > MaxWriterNameLength {
		return "", ErrWriterNameTooLong
	}
	if len(writerName) < MinWriterNameLength {
		return "", ErrWriterNameTooShort
	}
	return writerName, nil
}

// NewWriterName validates name into a WriterName.
func NewWriterName(name string) (WriterName, error) {
	validated, err := validateWriterName(name)
	if err != nil {
		return WriterName{}, err
	}
	return WriterName{validated}, nil
}

func (wn WriterName) String() string {
	return wn.Name
}
-------------------------------------------------------------------------------- /pkg/sqlite/driver.go: --------------------------------------------------------------------------------
package sqlite

import (
	"database/sql"
	"sync"

	"github.com/segmentio/go-sqlite3"
	// NOTE(review): this blank import appears redundant — the named import
	// above already links the package; confirm and remove.
	_ "github.com/segmentio/go-sqlite3"
)

func init() {
	InitDriver()
}

var initDriverOnce sync.Once

// InitDriver ensures that the sqlite3 driver is initialized
func InitDriver() {
	initDriverOnce.Do(func() {
		sql.Register("sqlite3_with_autocheckpoint_off", &sqlite3.SQLiteDriver{
			ConnectHook: func(conn *sqlite3.SQLiteConn) error {
				// This turns off automatic WAL checkpoints in the reader. Since the reader
				// can't do checkpoints as it's usually in read-only mode, checkpoints only
				// result in an error getting returned to callers in some circumstances.
				// As the Reflector is the only writer to the LDB, and it will continue to
				// run checkpoints, the WAL will stay nice and tidy.
				_, err := conn.Exec("PRAGMA wal_autocheckpoint = 0", nil)
				return err
			},
		})
	})
}
-------------------------------------------------------------------------------- /pkg/sqlite/sql_change_buffer.go: --------------------------------------------------------------------------------
package sqlite

import "sync"

// SQLChangeBuffer accumulates sqliteWatchChanges and allows them to be popped
// off later when writing the changelog.
type SQLChangeBuffer struct {
	mut     sync.Mutex
	changes []SQLiteWatchChange
}

// add appends a change to the end of the buffer
func (b *SQLChangeBuffer) Add(change SQLiteWatchChange) {
	b.mut.Lock()
	defer b.mut.Unlock()
	b.changes = append(b.changes, change)
}

// pop returns the accumulated changes and then resets the buffer
func (b *SQLChangeBuffer) Pop() []SQLiteWatchChange {
	b.mut.Lock()
	defer b.mut.Unlock()
	res := b.changes
	b.changes = nil // hand ownership of the slice to the caller
	return res
}
-------------------------------------------------------------------------------- /pkg/sqlite/sql_change_buffer_test.go: --------------------------------------------------------------------------------
package sqlite

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestChangeBuffer verifies Add/Pop ordering and that Pop drains the buffer.
func TestChangeBuffer(t *testing.T) {
	var buf SQLChangeBuffer
	assert.Len(t, buf.Pop(), 0)

	buf.Add(SQLiteWatchChange{
		DatabaseName: "t1",
	})
	buf.Add(SQLiteWatchChange{
		DatabaseName: "t2",
	})
	pop := buf.Pop()
	assert.EqualValues(t, []SQLiteWatchChange{
		{DatabaseName: "t1"},
		{DatabaseName: "t2"},
	}, pop)

	// verify there are no more changes
	assert.Len(t, buf.Pop(), 0)

}
-------------------------------------------------------------------------------- /pkg/sqlite/sqlite_info.go:
-------------------------------------------------------------------------------- 1 | package sqlite 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | 8 | "github.com/pkg/errors" 9 | "github.com/segmentio/ctlstore/pkg/schema" 10 | "github.com/segmentio/ctlstore/pkg/sqlgen" 11 | ) 12 | 13 | type SqliteDBInfo struct { 14 | Db *sql.DB 15 | } 16 | 17 | func (m *SqliteDBInfo) GetAllTables(ctx context.Context) ([]schema.FamilyTable, error) { 18 | var res []schema.FamilyTable 19 | rows, err := m.Db.QueryContext(ctx, "select distinct name from sqlite_master where type='table' order by name") 20 | if err != nil { 21 | return nil, errors.Wrap(err, "query table names") 22 | } 23 | for rows.Next() { 24 | var fullName string 25 | err = rows.Scan(&fullName) 26 | if err != nil { 27 | return nil, errors.Wrap(err, "scan table name") 28 | } 29 | if ft, ok := schema.ParseFamilyTable(fullName); ok { 30 | res = append(res, ft) 31 | } 32 | 33 | } 34 | return res, err 35 | } 36 | 37 | func (m *SqliteDBInfo) GetColumnInfo(ctx context.Context, tableNames []string) ([]schema.DBColumnInfo, error) { 38 | if len(tableNames) == 0 { 39 | return []schema.DBColumnInfo{}, nil 40 | } 41 | columnInfos := []schema.DBColumnInfo{} 42 | for _, tableName := range tableNames { 43 | err := func() error { 44 | qTableName, err := sqlgen.SQLQuote(tableName) 45 | if err != nil { 46 | return err 47 | } 48 | 49 | qs := fmt.Sprintf( 50 | "SELECT cid, name, type, pk FROM pragma_table_info(%s) "+ 51 | "ORDER BY cid ASC", 52 | qTableName) 53 | 54 | rows, err := m.Db.QueryContext(ctx, qs) 55 | if err != nil { 56 | return err 57 | } 58 | defer rows.Close() 59 | 60 | for rows.Next() { 61 | var colID int 62 | var colName string 63 | var dataType string 64 | var pk int 65 | 66 | err = rows.Scan(&colID, &colName, &dataType, &pk) 67 | if err != nil { 68 | return err 69 | } 70 | 71 | columnInfos = append(columnInfos, schema.DBColumnInfo{ 72 | TableName: tableName, 73 | Index: colID, 74 | ColumnName: colName, 75 
// Registers a hook against dbName that will populate the passed buffer with
// sqliteWatchChange messages each time a change is executed against the
// database. These messages are pre-update, so the buffer will be populated
// before the change is committed.
//
// Each captured change records the operation, the affected database/table,
// old and new row IDs, and (depending on the op) the old and/or new row
// values. The returned error is currently always nil; sql.Register panics
// if dbName is already registered.
func RegisterSQLiteWatch(dbName string, buffer *SQLChangeBuffer) error {
	sql.Register(dbName, &sqlite3.SQLiteDriver{
		ConnectHook: func(conn *sqlite3.SQLiteConn) error {
			conn.RegisterPreUpdateHook(func(pud sqlite3.SQLitePreUpdateData) {
				cnt := pud.Count()
				var newRow []interface{}
				var oldRow []interface{}

				// UPDATE and DELETE carry a pre-image of the row.
				if pud.Op == sqlite3.SQLITE_UPDATE || pud.Op == sqlite3.SQLITE_DELETE {
					oldRow = make([]interface{}, cnt)
					err := pud.Old(oldRow...)
					if err != nil {
						// NOTE(review): the pre-update hook has no error
						// path, so a failed Old() silently drops the whole
						// change instead of buffering it — confirm this is
						// acceptable to downstream changelog consumers.
						return
					}
				}

				// UPDATE and INSERT carry a post-image of the row.
				if pud.Op == sqlite3.SQLITE_UPDATE || pud.Op == sqlite3.SQLITE_INSERT {
					newRow = make([]interface{}, cnt)
					err := pud.New(newRow...)
					if err != nil {
						// NOTE(review): same silent-drop behavior as above.
						return
					}
				}

				buffer.Add(SQLiteWatchChange{
					Op:           pud.Op,
					DatabaseName: pud.DatabaseName,
					TableName:    pud.TableName,
					OldRowID:     pud.OldRowID,
					NewRowID:     pud.NewRowID,
					OldRow:       oldRow,
					NewRow:       newRow,
				})
			})
			return nil
		},
	})

	return nil
}
// fakeReadCloser is a test double that wraps an io.ReadCloser, records
// whether Read/Close were invoked, and can force either to return a
// canned error after delegating.
type fakeReadCloser struct {
	rc          io.ReadCloser    // delegate reader
	readErr     error            // if non-nil, returned from Read after delegating
	readCalled  utils.AtomicBool // set once Read has been invoked
	closeErr    error            // if non-nil, returned from Close after delegating
	closeCalled utils.AtomicBool // set once Close has been invoked
}

// Read delegates to the wrapped reader, then overrides the error with
// readErr when one is configured. The delegate's n is returned unchanged.
func (r *fakeReadCloser) Read(p []byte) (n int, err error) {
	r.readCalled.SetTrue()
	n, err = r.rc.Read(p)
	if r.readErr != nil {
		err = r.readErr
	}
	return n, err
}

// Close delegates to the wrapped closer, then overrides the error with
// closeErr when one is configured.
func (r *fakeReadCloser) Close() error {
	r.closeCalled.SetTrue()
	err := r.rc.Close()
	if r.closeErr != nil {
		// NOTE(review): leftover debug output — prints on every forced-error
		// Close during tests; consider removing it (and the then-unused
		// fmt import).
		fmt.Println("Returning", r.closeErr)
		err = r.closeErr
	}
	return err
}
-------------------------------------------------------------------------------- /pkg/supervisor/gzip_pipe.go: -------------------------------------------------------------------------------- 1 | package supervisor 2 | 3 | import ( 4 | "compress/gzip" 5 | "io" 6 | "sync" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | type gzipCompressionReader struct { 12 | reader io.Reader // the original reader 13 | pipeReader *io.PipeReader // what we'll actually read from 14 | bytesRead int // how many gzip bytes were transferred 15 | once sync.Once 16 | } 17 | 18 | var _ io.Reader = (*gzipCompressionReader)(nil) 19 | 20 | // newGZIPPipeReader provides a reader that reads a delegate reader's 21 | // bytes but compresses them as GZIP. It does this by using io.Pipe() 22 | // and a gzip writer that writes to the *PipeWriter. The read end of 23 | // the pipe is what is used to satisfy the io.Reader contract. 24 | func newGZIPCompressionReader(reader io.Reader) *gzipCompressionReader { 25 | return &gzipCompressionReader{ 26 | reader: reader, 27 | } 28 | } 29 | 30 | func (r *gzipCompressionReader) Read(p []byte) (n int, err error) { 31 | if r.reader == nil { 32 | return -1, errors.New("no reader specified") 33 | } 34 | r.once.Do(func() { 35 | var pw *io.PipeWriter 36 | r.pipeReader, pw = io.Pipe() 37 | gw := gzip.NewWriter(pw) 38 | go func() { 39 | pw.CloseWithError(func() error { 40 | _, err := io.Copy(gw, r.reader) 41 | if err != nil { 42 | return errors.Wrap(err, "copy to gzip writer") 43 | } 44 | if err = gw.Close(); err != nil { 45 | return errors.Wrap(err, "close gzip writer") 46 | } 47 | return nil 48 | }()) 49 | }() 50 | }) 51 | n, err = r.pipeReader.Read(p) 52 | if n > 0 { 53 | r.bytesRead += n 54 | } 55 | return n, err 56 | } 57 | -------------------------------------------------------------------------------- /pkg/supervisor/gzip_pipe_test.go: -------------------------------------------------------------------------------- 1 | package supervisor 2 | 3 | import ( 4 | 
"bytes" 5 | "compress/gzip" 6 | "io" 7 | "io/ioutil" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/pkg/errors" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestGZIPPipeReader(t *testing.T) { 16 | input := "hello world" 17 | var reader io.Reader = strings.NewReader(input) 18 | reader = newGZIPCompressionReader(reader) 19 | deflated, err := ioutil.ReadAll(reader) 20 | require.NoError(t, err) 21 | 22 | reader, err = gzip.NewReader(bytes.NewReader(deflated)) 23 | require.NoError(t, err) 24 | inflated, err := ioutil.ReadAll(reader) 25 | require.NoError(t, err) 26 | require.EqualValues(t, input, string(inflated)) 27 | } 28 | 29 | func TestGZIPPipeReaderErr(t *testing.T) { 30 | for _, test := range []struct { 31 | name string 32 | input io.ReadCloser 33 | readErr error 34 | closeErr error 35 | expected error 36 | }{ 37 | { 38 | name: "no err", 39 | input: ioutil.NopCloser(strings.NewReader("hello, world")), 40 | }, 41 | { 42 | name: "read err", 43 | input: ioutil.NopCloser(strings.NewReader("hello, world")), 44 | readErr: errors.New("read failed"), 45 | expected: errors.New("copy to gzip writer: read failed"), 46 | }, 47 | { 48 | name: "close err", 49 | input: ioutil.NopCloser(strings.NewReader("hello, world")), 50 | closeErr: errors.New("close failed"), 51 | expected: nil, // the gzip pipe reader should not close the input reader 52 | }, 53 | } { 54 | t.Run(test.name, func(t *testing.T) { 55 | fake := &fakeReadCloser{ 56 | rc: test.input, 57 | readErr: test.readErr, 58 | closeErr: test.closeErr, 59 | } 60 | reader := newGZIPCompressionReader(fake) 61 | _, err := ioutil.ReadAll(reader) 62 | if test.expected == nil { 63 | require.NoError(t, err) 64 | } else { 65 | require.EqualError(t, err, test.expected.Error()) 66 | require.True(t, fake.readCalled.IsSet()) 67 | require.False(t, fake.closeCalled.IsSet()) 68 | } 69 | }) 70 | } 71 | } 72 | 73 | // TestIOPipes serves as a reference on how to use this damn thing. 
74 | func TestIOPipes(t *testing.T) { 75 | const bufSize = 100 * 1024 76 | data := make([]byte, bufSize) 77 | 78 | var reader io.Reader 79 | 80 | // verify that the entire payload is read uncompressed 81 | 82 | reader = bytes.NewReader(data) 83 | deflated, err := ioutil.ReadAll(reader) 84 | require.NoError(t, err) 85 | require.Equal(t, bufSize, len(deflated)) 86 | 87 | // read the bytes as gzip 88 | 89 | pr, pw := io.Pipe() 90 | gw := gzip.NewWriter(pw) 91 | go func() { 92 | pw.CloseWithError(func() error { 93 | _, err := io.Copy(gw, bytes.NewReader(data)) 94 | if err != nil { 95 | return errors.Wrap(err, "copy to gw") 96 | } 97 | if err = gw.Close(); err != nil { 98 | return errors.Wrap(err, "close gzip writer") 99 | } 100 | return nil 101 | }()) 102 | }() 103 | 104 | deflated, err = ioutil.ReadAll(pr) 105 | require.NoError(t, err) 106 | require.True(t, len(deflated) < bufSize, "source=%d res=%d", bufSize, len(deflated)) 107 | 108 | reader, err = gzip.NewReader(bytes.NewReader(deflated)) 109 | require.NoError(t, err) 110 | inflated, err := ioutil.ReadAll(reader) 111 | require.NoError(t, err) 112 | require.EqualValues(t, data, inflated) 113 | } 114 | -------------------------------------------------------------------------------- /pkg/supervisor/s3_snapshot_test.go: -------------------------------------------------------------------------------- 1 | package supervisor 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "context" 7 | "io" 8 | "io/ioutil" 9 | "os" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func TestS3SnapshotCompression(t *testing.T) { 18 | for _, test := range []struct { 19 | name string 20 | url string 21 | compression bool 22 | payload string 23 | expectBucket string 24 | expectKey string 25 | }{ 26 | { 27 | name: "no compression", 28 | url: "s3://segment-ctlstore-snapshots-stage/snapshot.db", 29 | compression: false, 30 | payload: "s3 payload content", 31 | expectBucket: 
"segment-ctlstore-snapshots-stage", 32 | expectKey: "snapshot.db", 33 | }, 34 | { 35 | name: "with compression", 36 | url: "s3://segment-ctlstore-snapshots-stage/snapshot.db.gz", 37 | compression: true, 38 | payload: "s3 payload content", 39 | expectBucket: "segment-ctlstore-snapshots-stage", 40 | expectKey: "snapshot.db.gz", 41 | }, 42 | } { 43 | t.Run(test.name, func(t *testing.T) { 44 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 45 | defer cancel() 46 | snapshot, err := archivedSnapshotFromURL(test.url) 47 | require.NoError(t, err) 48 | s3snap, ok := snapshot.(*s3Snapshot) 49 | require.True(t, ok) 50 | var sent struct { 51 | key string 52 | bucket string 53 | bytes []byte 54 | } 55 | s3snap.sendToS3Func = func(ctx context.Context, key string, bucket string, body io.Reader) (err error) { 56 | sent.key = key 57 | sent.bucket = bucket 58 | sent.bytes, err = ioutil.ReadAll(body) 59 | return 60 | } 61 | file, err := ioutil.TempFile("", test.name) 62 | require.NoError(t, err) 63 | defer os.Remove(file.Name()) 64 | _, err = io.Copy(file, strings.NewReader(test.payload)) 65 | require.NoError(t, err) 66 | err = file.Close() 67 | require.NoError(t, err) 68 | err = snapshot.Upload(ctx, file.Name()) 69 | require.NoError(t, err) 70 | require.Equal(t, test.expectKey, sent.key) 71 | require.Equal(t, test.expectBucket, sent.bucket) 72 | if test.compression { 73 | r, err := gzip.NewReader(bytes.NewReader(sent.bytes)) 74 | require.NoError(t, err) 75 | b, err := ioutil.ReadAll(r) 76 | require.NoError(t, err) 77 | require.Equal(t, test.payload, string(b)) 78 | } else { 79 | require.Equal(t, test.payload, string(sent.bytes)) 80 | } 81 | }) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /pkg/supervisor/s3_uploader.go: -------------------------------------------------------------------------------- 1 | package supervisor 2 | 3 | import ( 4 | "github.com/aws/aws-sdk-go-v2/feature/s3/manager" 5 | ) 6 | 7 | 
//counterfeiter:generate -o fakes/s3_uploader.go . S3Client 8 | type S3Client interface { 9 | manager.UploadAPIClient 10 | } 11 | -------------------------------------------------------------------------------- /pkg/tests/tests.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "io/ioutil" 7 | "os" 8 | "path/filepath" 9 | "testing" 10 | "time" 11 | 12 | _ "github.com/go-sql-driver/mysql" 13 | "github.com/segmentio/ctlstore/pkg/ctldb" 14 | "github.com/segmentio/ctlstore/pkg/utils" 15 | ) 16 | 17 | func WithTmpDir(t testing.TB) (dir string, teardown func()) { 18 | tmpDir, err := ioutil.TempDir("", "") 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | return tmpDir, func() { 23 | os.RemoveAll(tmpDir) 24 | } 25 | } 26 | 27 | func WithTmpFile(t testing.TB, name string) (file *os.File, teardown func()) { 28 | var teardowns utils.Teardowns 29 | dir, teardown := WithTmpDir(t) 30 | teardowns.Add(teardown) 31 | 32 | path := filepath.Join(dir, name) 33 | var err error 34 | file, err = os.Create(path) 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | teardowns.Add(func() { file.Close() }) 39 | return file, teardowns.Teardown 40 | } 41 | 42 | func CheckCtldb(t *testing.T) { 43 | db, err := sql.Open("mysql", ctldb.GetTestCtlDBDSN(t)) 44 | if err == nil { 45 | ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) 46 | defer cancel() 47 | _, err = db.ExecContext(ctx, "SELECT 1") 48 | db.Close() 49 | if err == nil { 50 | return 51 | } 52 | } 53 | t.Fatalf(` 54 | *** Tests require MySQL to be up *** 55 | Error: %v" 56 | \n\nHINT: Have you ran 'docker-compose up'?\n\n 57 | `, err) 58 | } 59 | -------------------------------------------------------------------------------- /pkg/tests/tests_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | 
// Binary byte-size units. Each step of iota shifts by 10 bits, so each
// constant is 1024x the previous one: BYTE=1, KILOBYTE=1024, MEGABYTE,
// GIGABYTE, TERABYTE. The 1.0 literal keeps the constants untyped so they
// can be used in both integer and float expressions.
const (
	BYTE = 1.0 << (10 * iota)
	KILOBYTE
	MEGABYTE
	GIGABYTE
	TERABYTE
)
// InterfaceFactory stamps out interface{} values of a fixed pointer type
// (captured at construction) that point at arbitrary struct fields,
// without going through reflect.Value on each call.
type InterfaceFactory struct {
	ptritab unsafe.Pointer // itab for the *t interface pairing, reused for every produced interface{}
}

// NewInterfaceFactory captures the itab (interface type table) for *t so
// that PtrToStructField can later build interface{} values of type *t
// directly. t is the field's value type, not a pointer type.
func NewInterfaceFactory(t reflect.Type) InterfaceFactory {
	// Need something valid to point at, just a throwaway
	tmp := struct{}{}
	ptr := unsafe.Pointer(&tmp)

	// Construct a pointer of type *t that points at tmp
	ptrPtrVal := reflect.NewAt(t, ptr)

	// Build interface{*t, ptrPtrVal=>tmp}
	ptrPtrIface := ptrPtrVal.Interface()

	// Coerce the above interface{} into a touchable struct
	ptrPtrIptr := *(*iptr)(unsafe.Pointer(&ptrPtrIface))

	// All we care about is the itab field, which contains the
	// type information to copy onto factory-created interface{}s
	itabPtr := unsafe.Pointer(ptrPtrIptr.itab)

	return InterfaceFactory{
		ptritab: itabPtr,
	}
}

// PtrToStructField takes interface{Struct, ptr=>struct} and returns
// interface{*FieldType, ptr=>&struct.field}. 'any' must hold a pointer to
// a struct whose type declares 'field'; neither is validated here.
// NOTE(review): this relies on the runtime's (itab, data) interface
// layout and raw pointer arithmetic — toolchain-dependent by design.
func (f *InterfaceFactory) PtrToStructField(any interface{}, field reflect.StructField) interface{} {
	// creates an iptr struct out of the 'any' interface
	anyIptr := *(*iptr)(unsafe.Pointer(&any))

	// construct the new pointer (&struct.field) that will be returned
	interptr := unsafe.Pointer(uintptr(anyIptr.ptr) + field.Offset)

	// create a new iptr by copying the template iptr, which has the proper type
	newIptr := iptr{
		itab: f.ptritab,
		ptr:  interptr,
	}

	return newIptr.Interface()
}
// AtomicBool is a boolean flag that may be read and written concurrently
// from multiple goroutines. Its zero value is unset (false).
type AtomicBool int32

// IsSet reports whether the flag is currently set.
func (ab *AtomicBool) IsSet() bool {
	return atomic.LoadInt32((*int32)(ab)) != 0
}

// SetTrue sets the flag.
func (ab *AtomicBool) SetTrue() {
	atomic.StoreInt32((*int32)(ab), 1)
}

// SetFalse clears the flag.
func (ab *AtomicBool) SetFalse() {
	atomic.StoreInt32((*int32)(ab), 0)
}
12 | func EnsureDirForFile(file string) error { 13 | dir := filepath.Dir(file) 14 | _, err := os.Stat(dir) 15 | switch { 16 | case err == nil: 17 | return nil 18 | case os.IsNotExist(err): 19 | err = os.Mkdir(dir, 0700) 20 | return errors.Wrapf(err, "mkdir %s", dir) 21 | default: 22 | return errors.Wrapf(err, "stat %s", dir) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /pkg/utils/interface_slice.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "reflect" 4 | 5 | // Converts args to an interface slice using the following rules: 6 | // 7 | // - Returns empty slice if no args are passed 8 | // 9 | // - For a single argument which is of a slice type, the slice 10 | // is converted and returned. 11 | // 12 | // - For a single argument which is not a slice type, the value is 13 | // returned within a single-element slice. 14 | // 15 | // - For multiple arguments, returns a slice with all the args 16 | // 17 | func InterfaceSlice(any ...interface{}) []interface{} { 18 | if len(any) == 0 { 19 | return []interface{}{} 20 | } 21 | 22 | if len(any) == 1 { 23 | // FUTURE: there has to be a faster way to do this right? I guess 24 | // that under the hood this is what happens to the arguments 25 | // passed to the function. I'd assume it can elide a bunch of the 26 | // reflection given it knows the types. 
27 | v := reflect.ValueOf(any[0]) 28 | if v.Type().Kind() == reflect.Slice { 29 | vLen := v.Len() 30 | out := make([]interface{}, vLen) 31 | for i := 0; i < vLen; i++ { 32 | out[i] = v.Index(i).Interface() 33 | } 34 | return out 35 | } 36 | 37 | return []interface{}{any} 38 | } 39 | 40 | return any 41 | } 42 | -------------------------------------------------------------------------------- /pkg/utils/json_reader.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "io" 7 | ) 8 | 9 | // JsonReader is a convenience type that is constructed with a 10 | // type to be serialized using newJsonReader. it implements 11 | // io.Reader and writes JSON bytes to the client. Useful for 12 | // supplying a reader for the body of an http request. This 13 | // allows the client to omit the extra step of encoding a struct 14 | // into a byte slice and then passing a bytes.NewReader(b) to 15 | // something expecting that reader. 16 | type JsonReader struct { 17 | reader io.Reader 18 | err error 19 | } 20 | 21 | func (r *JsonReader) Read(p []byte) (n int, err error) { 22 | if r.err != nil { 23 | return -1, err 24 | } 25 | return r.reader.Read(p) 26 | } 27 | 28 | func NewJsonReader(val interface{}) *JsonReader { 29 | b, err := json.Marshal(val) 30 | return &JsonReader{ 31 | reader: bytes.NewReader(b), 32 | err: err, 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /pkg/utils/looper.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | // CtxLoop blocks and fires the callback function on a tick. 
// Teardowns collects cleanup functions and runs them in reverse
// registration order, mirroring defer semantics.
type Teardowns struct {
	funcs []func()
}

// Add registers fn to be invoked by Teardown.
func (t *Teardowns) Add(fn func()) {
	t.funcs = append(t.funcs, fn)
}

// Teardown invokes every registered function, most recently added first.
func (t *Teardowns) Teardown() {
	for n := len(t.funcs); n > 0; n-- {
		t.funcs[n-1]()
	}
}
"github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestTeardownsEmpty(t *testing.T) { 10 | var tds Teardowns 11 | tds.Teardown() 12 | } 13 | 14 | func TestTeardownsOrder(t *testing.T) { 15 | var tds Teardowns 16 | var nums []int 17 | 18 | tds.Add(func() { nums = append(nums, 1) }) 19 | tds.Add(func() { nums = append(nums, 2) }) 20 | tds.Add(func() { nums = append(nums, 3) }) 21 | 22 | tds.Teardown() 23 | require.EqualValues(t, []int{3, 2, 1}, nums) 24 | } 25 | -------------------------------------------------------------------------------- /pkg/version/version.go: -------------------------------------------------------------------------------- 1 | package version 2 | 3 | var version = "unknown" 4 | 5 | // Get returns the version of the ctlstore client library. 6 | func Get() string { 7 | return version 8 | } 9 | -------------------------------------------------------------------------------- /pkg/version/version_go1_12.go: -------------------------------------------------------------------------------- 1 | // +build go1.12 2 | 3 | package version 4 | 5 | import ( 6 | "runtime/debug" 7 | ) 8 | 9 | const path = "github.com/segmentio/ctlstore" 10 | 11 | // The version is extracted from build information embedded in the binary, from 12 | // a go.mod file, so this version field is only available in Go modules projects. 13 | // We determine the version dynamically instead of using -ldflags to inject the 14 | // version because ctlstore will be imported as a library, and we do not expect 15 | // consumers to set ctlstore's version for us. 16 | // Note: `debug.ReadBuildInfo` is only available in Go 1.12+, so we gate this 17 | // with the build tag above. 
18 | func init() { 19 | if info, ok := debug.ReadBuildInfo(); ok && info != nil { 20 | for _, mod := range info.Deps { 21 | if mod != nil { 22 | if mod.Path == path { 23 | version = mod.Version 24 | } 25 | } 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /rows.go: -------------------------------------------------------------------------------- 1 | package ctlstore 2 | 3 | import ( 4 | "database/sql" 5 | 6 | "github.com/segmentio/ctlstore/pkg/scanfunc" 7 | "github.com/segmentio/ctlstore/pkg/schema" 8 | ) 9 | 10 | // Rows composes an *sql.Rows and allows scanning ctlstore table rows into 11 | // structs or maps, similar to how the GetRowByKey reader method works. 12 | // 13 | // The contract around Next/Err/Close is the same was it is for 14 | // *sql.Rows. 15 | type Rows struct { 16 | rows *sql.Rows 17 | cols []schema.DBColumnMeta 18 | } 19 | 20 | // Next returns true if there's another row available. 21 | func (r *Rows) Next() bool { 22 | if r.rows == nil { 23 | return false 24 | } 25 | return r.rows.Next() 26 | } 27 | 28 | // Err returns any error that could have been caused during 29 | // the invocation of Next(). If Next() returns false, the caller 30 | // must always check Err() to see if that's why iteration 31 | // failed. 32 | func (r *Rows) Err() error { 33 | if r.rows == nil { 34 | return nil 35 | } 36 | return r.rows.Err() 37 | } 38 | 39 | // Close closes the underlying *sql.Rows. 40 | func (r *Rows) Close() error { 41 | if r.rows == nil { 42 | return nil 43 | } 44 | return r.rows.Close() 45 | } 46 | 47 | // Scan deserializes the current row into the specified target. 48 | // The target must be either a pointer to a struct, or a 49 | // map[string]interface{}. 
50 | func (r *Rows) Scan(target interface{}) error { 51 | if r.rows == nil { 52 | return sql.ErrNoRows 53 | } 54 | scanFunc, err := scanfunc.New(target, r.cols) 55 | if err != nil { 56 | return err 57 | } 58 | return scanFunc(r.rows) 59 | } 60 | -------------------------------------------------------------------------------- /scripts/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | set -eo pipefail 4 | 5 | CTLSTORE_BOOTSTRAP_URL=$1 6 | PREFIX="$(echo $CTLSTORE_BOOTSTRAP_URL | grep :// | sed -e's,^\(.*://\).*,\1,g')" 7 | URL="$(echo $CTLSTORE_BOOTSTRAP_URL | sed -e s,$PREFIX,,g)" 8 | BUCKET="$(echo $URL | grep / | cut -d/ -f1)" 9 | KEY="$(echo $URL | grep / | cut -d/ -f2)" 10 | CTLSTORE_DIR="/var/spool/ctlstore" 11 | CONCURRENCY=${2:-20} 12 | NUM_LDB=${3:-1} 13 | DOWNLOADED="false" 14 | COMPRESSED="false" 15 | METRICS="$CTLSTORE_DIR/metrics.json" 16 | 17 | mkdir -p $CTLSTORE_DIR 18 | cd $CTLSTORE_DIR 19 | 20 | # busybox does not support sub-second resolution 21 | START=$(date +%s) 22 | END=$(date +%s) 23 | SHA_START=$(date +%s) 24 | SHA_END=$(date +%s) 25 | 26 | get_head_object() { 27 | head_object=$(aws s3api head-object --bucket "${BUCKET}" --key "${KEY}") 28 | echo "$head_object" 29 | } 30 | 31 | cleanup() { 32 | echo "Removing snapshot.db" 33 | rm -f $CTLSTORE_DIR/snapshot.* 34 | } 35 | 36 | download_snapshot() { 37 | echo "Downloading head object from ${CTLSTORE_BOOTSTRAP_URL}" 38 | head_object=$(get_head_object) 39 | 40 | remote_checksum=$(printf '%s\n' "$head_object" | jq -r '.Metadata.checksum // empty') 41 | echo "Remote checksum in sha1: $remote_checksum" 42 | 43 | remote_version=$(printf '%s\n' "$head_object" | jq -r '.VersionId // empty') 44 | echo "Remote version: $remote_version" 45 | 46 | echo "Downloading snapshot from ${CTLSTORE_BOOTSTRAP_URL} with VersionID: ${remote_version}" 47 | s5cmd -r 0 --log debug cp --version-id $remote_version --concurrency $CONCURRENCY 
$CTLSTORE_BOOTSTRAP_URL . 48 | 49 | DOWNLOADED="true" 50 | if [[ ${CTLSTORE_BOOTSTRAP_URL: -2} == gz ]]; then 51 | echo "Decompressing" 52 | pigz -df snapshot.db.gz 53 | COMPRESSED="true" 54 | fi 55 | } 56 | 57 | check_sha() { 58 | SHA_START=$(date +%s) 59 | if [ -z $remote_checksum ]; then 60 | echo "Remote checksum sha1 is null, skipping checksum validation" 61 | else 62 | local_checksum=$(shasum snapshot.db | cut -f1 -d\ | xxd -r -p | base64) 63 | echo "Local snapshot checksum in sha1: $local_checksum" 64 | 65 | if [[ "$local_checksum" == "$remote_checksum" ]]; then 66 | echo "Checksum matches" 67 | else 68 | echo "Checksum does not match" 69 | echo "Failed to download intact snapshot" 70 | cleanup 71 | exit 1 72 | fi 73 | fi 74 | SHA_END=$(date +%s) 75 | echo "Local checksum calculation took $(($SHA_END - $SHA_START)) seconds" 76 | } 77 | 78 | if [ ! -f "$CTLSTORE_DIR/ldb.db" ]; then 79 | echo "No ldb found, downloading snapshot" 80 | download_snapshot 81 | check_sha 82 | 83 | i=2 84 | while [ "$i" -le $NUM_LDB ]; do 85 | if [ ! -f ldb-$i.db ]; then 86 | echo "creating copy ldb-$i.db" 87 | cp snapshot.db ldb-$i.db 88 | fi 89 | i=$((i + 1)) 90 | done 91 | 92 | mv snapshot.db ldb.db 93 | END=$(date +%s) 94 | echo "ldb.db ready in $(($END - $START)) seconds" 95 | else 96 | echo "Snapshot already present" 97 | fi 98 | 99 | # on existing nodes, we may already have the ldb file. 100 | # We should download a new snapshot to avoid copying an in-use ldb.db file and risking a malformed db 101 | i=2 102 | while [ "$i" -le $NUM_LDB ]; do 103 | 104 | # make sure it's not already downloaded 105 | if [ ! -f ldb-$i.db ]; then 106 | echo "Preparing ldb-$i.db" 107 | # download the snapshot if it's not present 108 | if [ ! 
-f "$CTLSTORE_DIR/snapshot.db" ]; then 109 | download_snapshot 110 | check_sha 111 | fi 112 | 113 | echo "creating copy ldb-$i.db" 114 | cp snapshot.db ldb-$i.db 115 | fi 116 | i=$((i + 1)) 117 | done 118 | 119 | cleanup 120 | 121 | echo "{\"startTime\": $(($END - $START)), \"downloaded\": \"$DOWNLOADED\", \"compressed\": \"$COMPRESSED\"}" >$METRICS 122 | cat $METRICS 123 | -------------------------------------------------------------------------------- /tools.go: -------------------------------------------------------------------------------- 1 | // +build tools 2 | 3 | package ctlstore 4 | 5 | import ( 6 | _ "github.com/maxbrunsfeld/counterfeiter/v6" 7 | ) 8 | 9 | // This file imports packages that are used when running go generate, or used 10 | // during the development process but not otherwise depended on by built code. 11 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package ctlstore 2 | 3 | import "github.com/segmentio/ctlstore/pkg/version" 4 | 5 | // Version is the current ctlstore client library version. 6 | var Version string 7 | 8 | func init() { 9 | Version = version.Get() 10 | } 11 | --------------------------------------------------------------------------------