├── .editorconfig ├── .github ├── renovate.json5 └── workflows │ └── go.yml ├── LICENSE ├── README.md ├── cmd └── simpledb │ └── main.go ├── go.mod ├── go.sum └── internal ├── buffer ├── buffer.go └── buffer_test.go ├── file ├── file.go └── file_test.go ├── index ├── btree.go ├── btree_test.go ├── index.go ├── scan.go └── scan_test.go ├── log ├── log.go └── log_test.go ├── metadata ├── metadata.go └── metadata_test.go ├── parse ├── lexer.go ├── lexer_test.go ├── parser.go └── parser_test.go ├── plan ├── error.go ├── group.go ├── group_test.go ├── materialize.go ├── mergejoin.go ├── mergejoin_test.go ├── multibuffer.go ├── multibuffer_test.go ├── plan.go ├── planner.go ├── planner_heuristic_test.go ├── planner_test.go ├── sort.go └── sort_test.go ├── postgres ├── postgres.go ├── server.go ├── server_bench_test.go └── server_test.go ├── query ├── query.go └── query_test.go ├── record ├── record.go ├── record_test.go └── schema │ └── schema.go ├── simpledb └── simpledb.go ├── statement └── statement.go ├── testdata ├── create_indexes.sql ├── create_tables.sql ├── example.sql ├── insert_data.sql ├── snapshots │ ├── tables │ │ ├── field_catalog.tbl │ │ ├── simpledb.log │ │ ├── table_catalog.tbl │ │ └── view_catalog.tbl │ ├── tables_data │ │ ├── courses.tbl │ │ ├── departments.tbl │ │ ├── field_catalog.tbl │ │ ├── index_catalog.tbl │ │ ├── sections.tbl │ │ ├── simpledb.log │ │ ├── students.tbl │ │ ├── table_catalog.tbl │ │ └── view_catalog.tbl │ └── tables_indexes_data │ │ ├── courses.tbl │ │ ├── courses_course_department_id_dir │ │ ├── courses_course_department_id_leaf │ │ ├── courses_pkey_dir │ │ ├── courses_pkey_leaf │ │ ├── departments.tbl │ │ ├── departments_pkey_dir │ │ ├── departments_pkey_leaf │ │ ├── field_catalog.tbl │ │ ├── index_catalog.tbl │ │ ├── sections.tbl │ │ ├── sections_pkey_dir │ │ ├── sections_pkey_leaf │ │ ├── sections_section_course_id_dir │ │ ├── sections_section_course_id_leaf │ │ ├── simpledb.log │ │ ├── students.tbl │ │ ├── students_pkey_dir │ 
│ ├── students_pkey_leaf │ │ ├── table_catalog.tbl │ │ └── view_catalog.tbl ├── testdata.go └── testdata_test.go └── transaction ├── transaction.go └── transaction_test.go /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | charset = utf-8 6 | trim_trailing_whitespace = true 7 | insert_final_newline = true 8 | -------------------------------------------------------------------------------- /.github/renovate.json5: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["config:base"], 3 | "packageRules": [ 4 | { 5 | "automerge": true, 6 | "matchUpdateTypes": ["minor", "patch"] 7 | } 8 | ], 9 | "major": { 10 | "automerge": false 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | lint-test-build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Go 15 | uses: actions/setup-go@v5 16 | with: 17 | go-version: '1.24' 18 | check-latest: true 19 | - uses: dominikh/staticcheck-action@v1 20 | with: 21 | install-go: false 22 | version: "latest" 23 | - name: Vet 24 | run: go vet ./... 25 | - name: Test 26 | run: go test -coverprofile=coverage.txt -v ./... 27 | - name: Build 28 | run: go build -v ./... 
29 | - name: Upload results to Codecov 30 | uses: codecov/codecov-action@v4 31 | with: 32 | token: ${{ secrets.CODECOV_TOKEN }} 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Kotaro Abe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # simple-db 2 | 3 | Implementation of the SimpleDB database system in Go. 
4 | 5 | http://www.cs.bc.edu/~sciore/simpledb/ 6 | 7 | ## Usage 8 | 9 | Launch the SimpleDB server with the following command: 10 | 11 | ```sh 12 | go run ./cmd/simpledb -dir "path/to/db" 13 | ``` 14 | 15 | Clients can connect to the server using the `psql` command line tool: 16 | 17 | ```sh 18 | psql -h localhost -p 45432 19 | ``` 20 | 21 | ## Supported SQL Commands / Examples 22 | 23 | See below files: 24 | - [internal/testdata/example.sql](internal/testdata/example.sql) 25 | - [internal/postgres/server_test.go](internal/postgres/server_test.go) -------------------------------------------------------------------------------- /cmd/simpledb/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "log" 7 | 8 | "github.com/abekoh/simple-db/internal/postgres" 9 | ) 10 | // main runs the SimpleDB server, storing data in the directory given by -dir, and exits via log.Fatal if the server returns an error. 11 | func main() { 12 | var dir string 13 | flag.StringVar(&dir, "dir", "", "directory to store data") 14 | flag.Parse() 15 | // NOTE(review): -dir defaults to ""; confirm postgres.RunServer handles an empty Dir, or make the flag required. 16 | ctx := context.Background() 17 | cfg := postgres.Config{ 18 | Dir: dir, 19 | } 20 | if err := postgres.RunServer(ctx, cfg); err != nil { 21 | log.Fatal(err) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/abekoh/simple-db 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/brianvoe/gofakeit/v7 v7.2.1 7 | github.com/google/go-cmp v0.7.0 8 | github.com/jackc/pgx/v5 v5.7.5 9 | github.com/oklog/ulid/v2 v2.1.1 10 | golang.org/x/sync v0.15.0 11 | ) 12 | 13 | require ( 14 | github.com/jackc/pgpassfile v1.0.0 // indirect 15 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 16 | golang.org/x/crypto v0.37.0 // indirect 17 | golang.org/x/text v0.24.0 // indirect 18 | ) 19 | -------------------------------------------------------------------------------- /go.sum: 
-------------------------------------------------------------------------------- 1 | github.com/brianvoe/gofakeit/v7 v7.0.4 h1:Mkxwz9jYg8Ad8NvT9HA27pCMZGFQo08MK6jD0QTKEww= 2 | github.com/brianvoe/gofakeit/v7 v7.0.4/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA= 3 | github.com/brianvoe/gofakeit/v7 v7.1.1/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA= 4 | github.com/brianvoe/gofakeit/v7 v7.1.2/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA= 5 | github.com/brianvoe/gofakeit/v7 v7.2.0/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA= 6 | github.com/brianvoe/gofakeit/v7 v7.2.1/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA= 7 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 9 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 10 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 11 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 12 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 13 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 14 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 15 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 16 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= 17 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 18 | github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= 19 | github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= 20 | github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= 21 | 
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= 22 | github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= 23 | github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= 24 | github.com/jackc/pgx/v5 v5.7.3 h1:PO1wNKj/bTAwxSJnO1Z4Ai8j4magtqg2SLNjEDzcXQo= 25 | github.com/jackc/pgx/v5 v5.7.3/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= 26 | github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg= 27 | github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= 28 | github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs= 29 | github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= 30 | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= 31 | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 32 | github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU= 33 | github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= 34 | github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s= 35 | github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= 36 | github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= 37 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 38 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 39 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 40 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 41 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 42 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 
43 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 44 | golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= 45 | golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= 46 | golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= 47 | golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= 48 | golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= 49 | golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= 50 | golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= 51 | golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= 52 | golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= 53 | golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 54 | golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= 55 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 56 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 57 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 58 | golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= 59 | golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 60 | golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= 61 | golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 62 | golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= 63 | golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 64 | golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= 65 | golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 66 | golang.org/x/sync v0.15.0 
h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= 67 | golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 68 | golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= 69 | golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= 70 | golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= 71 | golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= 72 | golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= 73 | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= 74 | golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= 75 | golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= 76 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 77 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 78 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 79 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 80 | -------------------------------------------------------------------------------- /internal/buffer/buffer.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync/atomic" 7 | "time" 8 | 9 | "github.com/abekoh/simple-db/internal/file" 10 | "github.com/abekoh/simple-db/internal/log" 11 | ) 12 |
13 | // Buffer holds the page of one disk block in memory, together with its pin 14 | // count and the transaction and LSN of its most recent modification. 15 | type Buffer struct { 16 | fm *file.Manager 17 | lm *log.Manager 18 | page *file.Page 19 | blockID file.BlockID 20 | pinsCount int32 21 | txNum int32 22 | lsn log.SequenceNumber 23 | } 24 |
25 | // NewBuffer returns an unassigned, unmodified buffer whose page is sized to 26 | // fm's block size. 27 | func NewBuffer(fm *file.Manager, lm *log.Manager) *Buffer { 28 | return &Buffer{ 29 | fm: fm, 30 | lm: lm, 31 | page: file.NewPage(fm.BlockSize()), 32 | txNum: -1, 33 | lsn: -1, 34 | } 35 | } 36 |
37 | // Page returns the in-memory contents of the buffer. 38 | func (b *Buffer) Page() *file.Page { 39 | return b.page 40 | } 41 | 42 | // BlockID returns the block currently assigned to the buffer. 43 | func (b *Buffer) BlockID() file.BlockID { 44 | return b.blockID 45 | } 46 | 47 | // TxNum returns the transaction that last modified the buffer, or -1 if the 48 | // buffer holds no unflushed modification. 49 | func (b *Buffer) TxNum() int32 { 50 | return b.txNum 51 | } 52 |
53 | // SetModified records that txNum modified the page. A negative lsn keeps the 54 | // previously recorded LSN. 55 | func (b *Buffer) SetModified(txNum int32, lsn log.SequenceNumber) { 56 | b.txNum = txNum 57 | if lsn >= 0 { 58 | b.lsn = lsn 59 | } 60 | } 61 | 62 | // IsPinned reports whether at least one pin is held on the buffer. 63 | func (b *Buffer) IsPinned() bool { 64 | return b.pinsCount > 0 65 | } 66 |
67 | // assignedToBlock flushes any pending modification and loads blockID into the 68 | // page. If the read fails the buffer is detached from every block, because the 69 | // page may hold partially read bytes that a later Pin must never be served. 70 | func (b *Buffer) assignedToBlock(blockID file.BlockID) error { 71 | if err := b.flush(); err != nil { 72 | return fmt.Errorf("could not flush: %w", err) 73 | } 74 | if err := b.fm.Read(blockID, b.page); err != nil { 75 | b.blockID = file.BlockID{} 76 | return fmt.Errorf("could not read: %w", err) 77 | } 78 | b.blockID = blockID 79 | b.pinsCount = 0 80 | return nil 81 | } 82 |
83 | // flush writes a modified page back to disk. The log is flushed up to the 84 | // buffer's LSN first, so the log records describing the change reach disk 85 | // before the data itself (write-ahead logging). 86 | func (b *Buffer) flush() error { 87 | if b.txNum < 0 { 88 | return nil 89 | } 90 | if err := b.lm.Flush(b.lsn); err != nil { 91 | return fmt.Errorf("could not flush: %w", err) 92 | } 93 | if err := b.fm.Write(b.blockID, b.page); err != nil { 94 | return fmt.Errorf("could not write: %w", err) 95 | } 96 | b.txNum = -1 97 | return nil 98 | } 99 |
100 | func (b *Buffer) pin() { 101 | b.pinsCount++ 102 | } 103 | 104 | // unpin releases one pin. Unpinning a buffer with no pins is a programming 105 | // error and panics. 106 | func (b *Buffer) unpin() { 107 | if b.pinsCount <= 0 { 108 | panic("unpin: pinsCount is already 0") 109 | } 110 | b.pinsCount-- 111 | } 112 |
113 | type ( 114 | // Manager owns a fixed pool of buffers. Pin counts and block assignments 115 | // are only changed by the goroutine running loop; the exported methods 116 | // talk to it over channels. 117 | Manager struct { 118 | pool []*Buffer 119 | availableNum atomic.Int32 120 | pinRequestCh chan pinRequest 121 | unpinCh chan unpinRequest 122 | flushAllCh chan flushAllRequest 123 | maxWaitTime time.Duration 124 | } 125 | // ManagerOption configures a Manager created by NewManager. 126 | ManagerOption func(*Manager) 127 | ) 128 |
129 | type ( 130 | bufferResult struct { 131 | buf *Buffer 132 | err error 133 | } 134 | pinRequest struct { 135 | blockID file.BlockID 136 | receiveCh chan<- bufferResult 137 | // cancelCh holds a value once the requester has stopped waiting. 138 | cancelCh <-chan struct{} 139 | } 140 | unpinRequest struct { 141 | buf *Buffer 142 | completeCh chan<- struct{} 143 | } 144 | flushAllRequest struct { 145 | txNum int32 146 | errCh chan<- error 147 | } 148 | ) 149 | 150 | const defaultMaxWaitTime = 10 * time.Second 151 |
152 | // NewManager creates a pool of buffNum buffers and starts the goroutine that 153 | // manages it; that goroutine exits when ctx is cancelled. 154 | func NewManager( 155 | ctx context.Context, 156 | fm *file.Manager, 157 | lm *log.Manager, 158 | buffNum int, 159 | opts ...ManagerOption, 160 | ) *Manager { 161 | pool := make([]*Buffer, buffNum) 162 | for i := range pool { 163 | pool[i] = NewBuffer(fm, lm) 164 | } 165 | m := &Manager{ 166 | pool: pool, 167 | availableNum: atomic.Int32{}, 168 | pinRequestCh: make(chan pinRequest), 169 | unpinCh: make(chan unpinRequest), 170 | flushAllCh: make(chan flushAllRequest), 171 | maxWaitTime: defaultMaxWaitTime, 172 | } 173 | m.availableNum.Store(int32(buffNum)) 174 | for _, opt := range opts { 175 | opt(m) 176 | } 177 | go m.loop(ctx) 178 | return m 179 | } 180 |
181 | // WithMaxWaitTime sets how long Pin waits for a buffer before giving up. 182 | func WithMaxWaitTime(d time.Duration) ManagerOption { 183 | return func(m *Manager) { 184 | m.maxWaitTime = d 185 | } 186 | } 187 |
188 | // loop serializes every change to the pool. Pin requests that cannot be 189 | // served immediately are queued in waitMap, keyed by the requested block. 190 | func (m *Manager) loop(ctx context.Context) { 191 | waitMap := make(map[file.BlockID][]pinRequest) 192 | for { 193 | select { 194 | case <-ctx.Done(): 195 | return 196 | case flushAllReq := <-m.flushAllCh: 197 | flushAllReq.errCh <- m.flushTx(flushAllReq.txNum) 198 | case unpinReq := <-m.unpinCh: 199 | buf := unpinReq.buf 200 | buf.unpin() 201 | if !handOff(waitMap, buf) && !buf.IsPinned() { 202 | m.availableNum.Add(1) 203 | } 204 | unpinReq.completeCh <- struct{}{} 205 | case pinReq := <-m.pinRequestCh: 206 | m.handlePin(waitMap, pinReq) 207 | } 208 | } 209 | } 210 |
211 | // flushTx writes every buffer modified by txNum to disk, stopping at and 212 | // returning the first error. 213 | func (m *Manager) flushTx(txNum int32) error { 214 | for _, b := range m.pool { 215 | if b.txNum != txNum { 216 | continue 217 | } 218 | if err := b.flush(); err != nil { 219 | return fmt.Errorf("could not flush %v: %w", b.blockID, err) 220 | } 221 | } 222 | return nil 223 | } 224 |
225 | // handlePin serves one pin request: it reuses a buffer that already holds the 226 | // block, otherwise loads the block into the first unpinned buffer, otherwise 227 | // queues the request until a buffer holding the block is unpinned. 228 | func (m *Manager) handlePin(waitMap map[file.BlockID][]pinRequest, req pinRequest) { 229 | for _, b := range m.pool { 230 | if b.blockID == req.blockID { 231 | m.pinAndDeliver(b, req) 232 | return 233 | } 234 | } 235 | for _, b := range m.pool { 236 | if b.IsPinned() { 237 | continue 238 | } 239 | if err := b.assignedToBlock(req.blockID); err != nil { 240 | deliver(req, bufferResult{err: err}) 241 | return 242 | } 243 | m.pinAndDeliver(b, req) 244 | return 245 | } 246 | select { 247 | case <-req.cancelCh: 248 | // The requester already gave up; do not queue it. 249 | return 250 | default: 251 | } 252 | waitMap[req.blockID] = append(waitMap[req.blockID], req) 253 | } 254 |
255 | // pinAndDeliver pins b on behalf of req and hands it over. If req has already 256 | // given up, the pin is rolled back so the buffer is not leaked. 257 | func (m *Manager) pinAndDeliver(b *Buffer, req pinRequest) { 258 | if !b.IsPinned() { 259 | m.availableNum.Add(-1) 260 | } 261 | b.pin() 262 | if deliver(req, bufferResult{buf: b}) { 263 | return 264 | } 265 | b.unpin() 266 | if !b.IsPinned() { 267 | m.availableNum.Add(1) 268 | } 269 | } 270 |
271 | // handOff passes buf, which just had a pin released, to the oldest requester 272 | // still queued for its block and reports whether it did so. Requesters that 273 | // already gave up are dropped from the queue. 274 | func handOff(waitMap map[file.BlockID][]pinRequest, buf *Buffer) bool { 275 | queue := waitMap[buf.blockID] 276 | for len(queue) > 0 { 277 | req := queue[0] 278 | queue = queue[1:] 279 | buf.pin() 280 | if deliver(req, bufferResult{buf: buf}) { 281 | if len(queue) > 0 { 282 | waitMap[buf.blockID] = queue 283 | } else { 284 | delete(waitMap, buf.blockID) 285 | } 286 | return true 287 | } 288 | buf.unpin() 289 | } 290 | delete(waitMap, buf.blockID) 291 | return false 292 | } 293 |
294 | // deliver sends res to req unless the requester has stopped waiting, and 295 | // reports whether res was received. It cannot block forever on a requester 296 | // that timed out, because Pin signals cancelCh before returning. 297 | func deliver(req pinRequest, res bufferResult) bool { 298 | select { 299 | case req.receiveCh <- res: 300 | return true 301 | case <-req.cancelCh: 302 | return false 303 | } 304 | } 305 |
306 | // AvailableNum reports the number of unpinned buffers in the pool. 307 | func (m *Manager) AvailableNum() int { 308 | return int(m.availableNum.Load()) 309 | } 310 | 311 | // FlushAll writes every buffer modified by txNum to disk and returns the 312 | // first error encountered. 313 | func (m *Manager) FlushAll(txNum int32) error { 314 | ch := make(chan error) 315 | m.flushAllCh <- flushAllRequest{txNum: txNum, errCh: ch} 316 | return <-ch 317 | } 318 |
319 | // Pin pins the buffer holding blockID, reading the block into an unpinned 320 | // buffer if necessary. It fails if the block cannot be read or no buffer 321 | // becomes available within the manager's maximum wait time. 322 | func (m *Manager) Pin(blockID file.BlockID) (*Buffer, error) { 323 | resCh := make(chan bufferResult) 324 | cancelCh := make(chan struct{}, 1) 325 | m.pinRequestCh <- pinRequest{blockID: blockID, receiveCh: resCh, cancelCh: cancelCh} 326 | select { 327 | case res := <-resCh: 328 | return res.buf, res.err 329 | case <-time.After(m.maxWaitTime): 330 | // Buffered, so this never blocks; the manager checks it before 331 | // queueing or delivering to this request. 332 | cancelCh <- struct{}{} 333 | return nil, fmt.Errorf("could not pin %v", blockID) 334 | } 335 | } 336 |
337 | // Unpin releases one pin on b. If a Pin call is queued for b's block, b is 338 | // handed to it directly. 339 | func (m *Manager) Unpin(b *Buffer) { 340 | ch := make(chan struct{}) 341 | m.unpinCh <- unpinRequest{buf: b, completeCh: ch} 342 | <-ch 343 | } 344 | -------------------------------------------------------------------------------- /internal/buffer/buffer_test.go: 
-------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/abekoh/simple-db/internal/file" 9 | "github.com/abekoh/simple-db/internal/log" 10 | ) 11 | 12 | func mustPin(t *testing.T, bm *Manager, blockID file.BlockID) *Buffer { 13 | t.Helper() 14 | buf, err := bm.Pin(blockID) 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | return buf 19 | } 20 | 21 | func assertAvailableNum(t *testing.T, bm *Manager, expected int) { 22 | t.Helper() 23 | if bm.AvailableNum() != expected { 24 | t.Errorf("expected %d, got %d", expected, bm.AvailableNum()) 25 | } 26 | } 27 |
28 | func TestBufferManager(t *testing.T) { 29 | t.Parallel() 30 | t.Run("Pin and Unpin", func(t *testing.T) { 31 | t.Parallel() 32 | fm, err := file.NewManager(t.TempDir(), 128) 33 | if err != nil { 34 | t.Fatal(err) 35 | } 36 | lm, err := log.NewManager(fm, "logfile") 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | ctx := context.Background() 41 | bm := NewManager(ctx, fm, lm, 3, WithMaxWaitTime(10*time.Millisecond)) 42 | 43 | assertAvailableNum(t, bm, 3) 44 | 45 | bufs := make([]*Buffer, 6) 46 | bufs[0] = mustPin(t, bm, file.NewBlockID("testfile", 0)) 47 | assertAvailableNum(t, bm, 2) 48 | 49 | bufs[1] = mustPin(t, bm, file.NewBlockID("testfile", 1)) 50 | bufs[2] = mustPin(t, bm, file.NewBlockID("testfile", 2)) 51 | assertAvailableNum(t, bm, 0) 52 | 53 | bufs[3] = mustPin(t, bm, file.NewBlockID("testfile", 0)) 54 | bm.Unpin(bufs[1]) 55 | bufs[1] = nil 56 | assertAvailableNum(t, bm, 1) 57 | 58 | bufs[3] = mustPin(t, bm, file.NewBlockID("testfile", 0)) 59 | bufs[4] = mustPin(t, bm, file.NewBlockID("testfile", 1)) 60 | assertAvailableNum(t, bm, 0) 61 | 62 | _, err = bm.Pin(file.NewBlockID("testfile", 3)) 63 | if err == nil || err.Error() != "could not pin testfile:3" { 64 | t.Errorf("expected could not pin testfile:3, got %s", err) 65 | } 66 | 67 | bm.Unpin(bufs[2]) 68 | bufs[2] = nil 69 | assertAvailableNum(t, bm, 1) 70 | 71 | bufs[5] = mustPin(t, bm, file.NewBlockID("testfile", 3)) 72 | assertAvailableNum(t, bm, 0) 73 | 74 | if bufs[0].BlockID() != file.NewBlockID("testfile", 0) { 75 | t.Errorf("expected testfile:0, got %s", bufs[0].BlockID()) 76 | } 77 | if bufs[3].BlockID() != file.NewBlockID("testfile", 0) { 78 | t.Errorf("expected testfile:0, got %s", bufs[3].BlockID()) 79 | } 80 | if bufs[4].BlockID() != file.NewBlockID("testfile", 1) { 81 | t.Errorf("expected testfile:1, got %s", bufs[4].BlockID()) 82 | } 83 | if bufs[5].BlockID() != file.NewBlockID("testfile", 3) { 84 | t.Errorf("expected testfile:3, got %s", bufs[5].BlockID()) 85 | } 86 | }) 87 | t.Run("FlushAll", func(t *testing.T) { 88 | t.Parallel() 89 | fm, err := file.NewManager(t.TempDir(), 128) 90 | if err != nil { 91 | t.Fatal(err) 92 | } 93 | lm, err := log.NewManager(fm, "logfile") 94 | if err != nil { 95 | t.Fatal(err) 96 | } 97 | ctx := context.Background() 98 | bm := NewManager(ctx, fm, lm, 3, WithMaxWaitTime(10*time.Millisecond)) 99 | 100 | buf1 := mustPin(t, bm, file.NewBlockID("testfile", 0)) 101 | buf1.Page().SetStr(0, "abcdefgh") 102 | buf1.SetModified(1, 1) 103 | 104 | if err := bm.FlushAll(1); err != nil { 105 | t.Fatal(err) 106 | } 107 | }) 108 | } 109 | -------------------------------------------------------------------------------- /internal/file/file.go: -------------------------------------------------------------------------------- 
1 | package file 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "os" 9 | "path/filepath" 10 | "sync" 11 | ) 12 | 13 | const int32Size = 4 14 |
15 | // BlockID identifies a block by file name and block number within that file. 16 | type BlockID struct { 17 | filename string 18 | blkNum int32 19 | } 20 | 21 | // NewBlockID returns the ID of block blkNum of filename. 22 | func NewBlockID(filename string, blkNum int32) BlockID { 23 | return BlockID{ 24 | filename: filename, 25 | blkNum: blkNum, 26 | } 27 | } 28 | 29 | // String formats the ID as "filename:blkNum". 30 | func (b BlockID) String() string { 31 | return fmt.Sprintf("%s:%d", b.filename, b.blkNum) 32 | } 33 | 34 | // Filename returns the file the block belongs to. 35 | func (b BlockID) Filename() string { 36 | return b.filename 37 | } 38 | 39 | // Num returns the block's index within its file. 40 | func (b BlockID) Num() int32 { 41 | return b.blkNum 42 | } 43 |
44 | // Page is an in-memory byte buffer with helpers for little-endian int32 45 | // values and length-prefixed byte strings. 46 | type Page struct { 47 | bb []byte 48 | } 49 | 50 | // NewPage returns a zeroed page of blockSize bytes. 51 | func NewPage(blockSize int32) *Page { 52 | return &Page{ 53 | bb: make([]byte, blockSize), 54 | } 55 | } 56 | 57 | // NewPageBytes wraps b without copying it. 58 | func NewPageBytes(b []byte) *Page { 59 | return &Page{ 60 | bb: b, 61 | } 62 | } 63 |
64 | // Int32 decodes the little-endian int32 stored at offset. 65 | func (p *Page) Int32(offset int32) int32 { 66 | return int32(binary.LittleEndian.Uint32(p.bb[offset:])) 67 | } 68 | 69 | // SetInt32 encodes n as a little-endian int32 at offset. 70 | func (p *Page) SetInt32(offset int32, n int32) { 71 | binary.LittleEndian.PutUint32(p.bb[offset:], uint32(n)) 72 | } 73 | 74 | // RawBytes returns the page's underlying byte slice, not a copy. 75 | func (p *Page) RawBytes() []byte { 76 | return p.bb 77 | } 78 |
79 | // Bytes returns the length-prefixed byte string stored at offset. The result 80 | // aliases the page's memory. 81 | func (p *Page) Bytes(offset int32) []byte { 82 | n := p.Int32(offset) 83 | return p.bb[offset+int32Size : offset+int32Size+n] 84 | } 85 | 86 | // SetBytes stores b at offset, prefixed by its length as an int32. 87 | func (p *Page) SetBytes(offset int32, b []byte) { 88 | p.SetInt32(offset, int32(len(b))) 89 | copy(p.bb[offset+int32Size:offset+int32Size+int32(len(b))], b) 90 | } 91 | 92 | // Str returns the length-prefixed string stored at offset. 93 | func (p *Page) Str(offset int32) string { 94 | return string(p.Bytes(offset)) 95 | } 96 | 97 | // SetStr stores s at offset, prefixed by its length in bytes. 98 | func (p *Page) SetStr(offset int32, s string) { 99 | p.SetBytes(offset, []byte(s)) 100 | } 101 |
102 | // PageStrMaxLengthByStr returns the number of page bytes needed to store s. 103 | func PageStrMaxLengthByStr(s string) int32 { 104 | return PageStrMaxLength(int32(len(s))) 105 | } 106 | 107 | // PageStrMaxLength returns the number of page bytes needed to store a string 108 | // of l bytes, including its length prefix. 109 | func PageStrMaxLength(l int32) int32 { 110 | return int32Size + l 111 | } 112 |
113 | // Manager reads and writes fixed-size blocks of the files in a database 114 | // directory, keeping each opened file open for reuse. 115 | type Manager struct { 116 | dbDirPath string 117 | blockSize int32 118 | isNew bool 119 | // mu guards openFiles, which is shared by calls for different files that 120 | // kmu lets run concurrently; kmu serializes I/O on a single file. 121 | mu sync.Mutex 122 | openFiles map[string]*os.File 123 | kmu keyedMutex 124 | } 125 |
126 | // NewManager opens the database directory dbDirPath, creating it when os.Stat 127 | // fails (e.g. it does not exist). The directory counts as new if it was created or was empty. 128 | func NewManager(dbDirPath string, blockSize int32) (*Manager, error) { 129 | isNew := false 130 | fi, err := os.Stat(dbDirPath) 131 | if err != nil { 132 | isNew = true 133 | if err := os.Mkdir(dbDirPath, 0755); err != nil { 134 | return nil, err 135 | } 136 | } else { 137 | if !fi.IsDir() { 138 | return nil, fmt.Errorf("%s is not a directory", dbDirPath) 139 | } 140 | files, err := os.ReadDir(dbDirPath) 141 | if err != nil { 142 | return nil, fmt.Errorf("could not read directory: %w", err) 143 | } 144 | if len(files) == 0 { 145 | isNew = true 146 | } 147 | } 148 | return &Manager{ 149 | dbDirPath: dbDirPath, 150 | blockSize: blockSize, 151 | isNew: isNew, 152 | openFiles: make(map[string]*os.File), 153 | kmu: keyedMutex{}, 154 | }, nil 155 | } 156 |
157 | // getFile returns the handle for filename, opening (and creating) the file on 158 | // first use. The returned unlock releases filename's I/O lock and must be 159 | // called even when err is non-nil. 160 | func (m *Manager) getFile(filename string) (*os.File, func(), error) { 161 | unlock := m.kmu.lock(filename) 162 | m.mu.Lock() 163 | f, ok := m.openFiles[filename] 164 | m.mu.Unlock() 165 | if ok { 166 | return f, unlock, nil 167 | } 168 | f, err := os.OpenFile(filepath.Join(m.dbDirPath, filename), os.O_RDWR|os.O_CREATE, 0644) 169 | if err != nil { 170 | return nil, unlock, fmt.Errorf("could not open %s: %w", filename, err) 171 | } 172 | m.mu.Lock() 173 | m.openFiles[filename] = f 174 | m.mu.Unlock() 175 | return f, unlock, nil 176 | } 177 |
178 | // offset returns the byte offset of block blkNum. It multiplies in int64 179 | // because the int32 product overflows for files larger than 2 GiB. 180 | func (m *Manager) offset(blkNum int32) int64 { 181 | return int64(blkNum) * int64(m.blockSize) 182 | } 183 |
184 | // Read fills p with the contents of blk. Bytes beyond the end of the file are 185 | // zeroed so p never keeps stale data from its previous use. 186 | func (m *Manager) Read(blk BlockID, p *Page) error { 187 | f, unlock, err := m.getFile(blk.filename) 188 | defer unlock() 189 | if err != nil { 190 | return fmt.Errorf("could not get file: %w", err) 191 | } 192 | n, err := f.ReadAt(p.bb, m.offset(blk.blkNum)) 193 | if err != nil && !errors.Is(err, io.EOF) { 194 | return fmt.Errorf("could not read at %d: %w", m.offset(blk.blkNum), err) 195 | } 196 | clear(p.bb[n:]) 197 | return nil 198 | } 199 |
200 | // Write stores p as the contents of blk, extending the file if necessary. 201 | func (m *Manager) Write(blk BlockID, p *Page) error { 202 | f, unlock, err := m.getFile(blk.filename) 203 | defer unlock() 204 | if err != nil { 205 | return fmt.Errorf("could not get file: %w", err) 206 | } 207 | if _, err := f.WriteAt(p.bb, m.offset(blk.blkNum)); err != nil { 208 | return fmt.Errorf("could not write at %d: %w", m.offset(blk.blkNum), err) 209 | } 210 | return nil 211 | } 212 |
213 | // Append adds a zero-filled block to the end of filename and returns its ID. 214 | func (m *Manager) Append(filename string) (BlockID, error) { 215 | f, unlock, err := m.getFile(filename) 216 | defer unlock() 217 | if err != nil { 218 | return BlockID{}, fmt.Errorf("could not get file: %w", err) 219 | } 220 | blkNum, err := m.lengthFromFile(f) 221 | if err != nil { 222 | return BlockID{}, fmt.Errorf("could not get lengthFromFile: %w", err) 223 | } 224 | // Only write at the explicit offset: f.Write would use the file's seek 225 | // position, which ReadAt/WriteAt never advance, and could zero an 226 | // existing block. 227 | b := make([]byte, m.blockSize) 228 | if _, err := f.WriteAt(b, m.offset(blkNum)); err != nil { 229 | return BlockID{}, fmt.Errorf("could not write at %d: %w", m.offset(blkNum), err) 230 | } 231 | return NewBlockID(filename, blkNum), nil 232 | } 233 |
234 | // Length returns the number of whole blocks in filename. 235 | func (m *Manager) Length(filename string) (int32, error) { 236 | f, unlock, err := m.getFile(filename) 237 | defer unlock() 238 | if err != nil { 239 | return 0, fmt.Errorf("could not get file: %w", err) 240 | } 241 | return m.lengthFromFile(f) 242 | } 243 | 244 | func (m *Manager) lengthFromFile(f *os.File) (int32, error) { 245 | fi, err := f.Stat() 246 | if err != nil { 247 | return 0, fmt.Errorf("could not stat file: %w", err) 248 | } 249 | return int32(fi.Size() / int64(m.blockSize)), nil 250 | } 251 |
252 | // BlockSize returns the size in bytes of every block. 253 | func (m *Manager) BlockSize() int32 { 254 | return m.blockSize 255 | } 256 | 257 | // IsNew reports whether the database directory was created or empty when the 258 | // manager was constructed. 259 | func (m *Manager) IsNew() bool { 260 | return m.isNew 261 | } 262 |
263 | // keyedMutex provides one lazily created mutex per key. 264 | type keyedMutex struct { 265 | mutexes sync.Map 266 | } 267 | 268 | // lock acquires key's mutex and returns a function that releases it. 269 | func (m *keyedMutex) lock(key string) func() { 270 | mu, _ := m.mutexes.LoadOrStore(key, &sync.Mutex{}) 271 | mu.(*sync.Mutex).Lock() 272 | return func() { 273 | mu.(*sync.Mutex).Unlock() 274 | } 275 | } 276 |
277 | // NeedSize returns the number of page bytes needed to store target when it 278 | // is an int32, string, or []byte, and 0 for any other type. 279 | func NeedSize(target any) int32 { 280 | switch impl := target.(type) { 281 | case int32: 282 | return int32Size 283 | case string: 284 | return int32Size + int32(len(impl)) 285 | case []byte: 286 | return int32Size + int32(len(impl)) 287 | default: 288 | return 0 289 | } 290 | } 291 | -------------------------------------------------------------------------------- /internal/file/file_test.go: -------------------------------------------------------------------------------- 1 | package file 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 |
| ) 7 | 8 | func TestPage(t *testing.T) { 9 | t.Parallel() 10 | t.Run("Int32 with offset 0", func(t *testing.T) { 11 | t.Parallel() 12 | p := NewPage(128) 13 | p.SetInt32(0, 123) 14 | if p.Int32(0) != 123 { 15 | t.Errorf("expected 123, got %d", p.Int32(0)) 16 | } 17 | }) 18 | t.Run("Int32 with offset 4", func(t *testing.T) { 19 | t.Parallel() 20 | p := NewPage(128) 21 | p.SetInt32(4, 123) 22 | if p.Int32(4) != 123 { 23 | t.Errorf("expected 123, got %d", p.Int32(4)) 24 | } 25 | }) 26 | t.Run("Bytes with offset 0", func(t *testing.T) { 27 | t.Parallel() 28 | p := NewPage(128) 29 | p.SetBytes(0, []byte{1, 2, 3, 4}) 30 | if string(p.Bytes(0)) != "\x01\x02\x03\x04" { 31 | t.Errorf("expected \\x01\\x02\\x03\\x04, got %q", p.Bytes(0)) 32 | } 33 | }) 34 | t.Run("Bytes with offset 4", func(t *testing.T) { 35 | t.Parallel() 36 | p := NewPage(128) 37 | p.SetBytes(4, []byte{1, 2, 3, 4}) 38 | if string(p.Bytes(4)) != "\x01\x02\x03\x04" { 39 | t.Errorf("expected \\x01\\x02\\x03\\x04, got %q", p.Bytes(4)) 40 | } 41 | }) 42 | t.Run("Str with offset 0", func(t *testing.T) { 43 | t.Parallel() 44 | p := NewPage(128) 45 | p.SetStr(0, "abcdefghijklmn") 46 | if p.Str(0) != "abcdefghijklmn" { 47 | t.Errorf("expected abcdefghijklmn, got %s", p.Str(0)) 48 | } 49 | }) 50 | t.Run("Str with offset 4", func(t *testing.T) { 51 | t.Parallel() 52 | p := NewPage(128) 53 | p.SetStr(4, "abcdghijklmn") 54 | if p.Str(4) != "abcdghijklmn" { 55 | t.Errorf("expected abcdefghijklmn, got %s", p.Str(4)) 56 | } 57 | }) 58 | t.Run("NewPageBytes", func(t *testing.T) { 59 | t.Parallel() 60 | p := NewPageBytes([]byte("abcdefghijklmn")) 61 | if string(p.bb) != "abcdefghijklmn" { 62 | t.Errorf("expected abcdefghijklmn, got %s", p.bb) 63 | } 64 | }) 65 | } 66 | 67 | func TestFileManager(t *testing.T) { 68 | t.Parallel() 69 | t.Run("Read and write", func(t *testing.T) { 70 | t.Parallel() 71 | fm, err := NewManager(t.TempDir(), 128) 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | 76 | blockID := 
NewBlockID("testfile", 0) 77 | writeP := NewPage(128) 78 | writeP.SetStr(0, "abcd") 79 | readP := NewPage(128) 80 | 81 | err = fm.Write(blockID, writeP) 82 | if err != nil { 83 | t.Fatal(err) 84 | } 85 | err = fm.Read(blockID, readP) 86 | if err != nil { 87 | t.Fatal(err) 88 | } 89 | 90 | if string(readP.Str(0)) != "abcd" { 91 | t.Errorf("expected abcd, got %s", readP.Str(0)) 92 | } 93 | }) 94 | t.Run("Read and write with offset", func(t *testing.T) { 95 | t.Parallel() 96 | fm, err := NewManager(t.TempDir(), 128) 97 | if err != nil { 98 | t.Fatal(err) 99 | } 100 | 101 | blockID := NewBlockID("testfile", 0) 102 | writeP := NewPage(128) 103 | writeP.SetStr(4, "abcd") 104 | readP := NewPage(128) 105 | 106 | err = fm.Write(blockID, writeP) 107 | if err != nil { 108 | t.Fatal(err) 109 | } 110 | err = fm.Read(blockID, readP) 111 | if err != nil { 112 | t.Fatal(err) 113 | } 114 | 115 | if string(readP.Str(4)) != "abcd" { 116 | t.Errorf("expected abcd, got %s", readP.Str(4)) 117 | } 118 | }) 119 | t.Run("Append", func(t *testing.T) { 120 | t.Parallel() 121 | fm, err := NewManager(t.TempDir(), 128) 122 | if err != nil { 123 | t.Fatal(err) 124 | } 125 | 126 | blockID := NewBlockID("testfile", 0) 127 | writeP := NewPage(128) 128 | writeP.SetStr(0, "abcd") 129 | 130 | err = fm.Write(blockID, writeP) 131 | if err != nil { 132 | t.Fatal(err) 133 | } 134 | // Block 0 already exists, so the appended block should be block 1. 135 | newBlockID, err := fm.Append("testfile") 136 | if err != nil { 137 | t.Fatal(err) 138 | } 139 | if !reflect.DeepEqual(newBlockID, NewBlockID("testfile", 1)) { 140 | t.Errorf("expected %v, got %v", NewBlockID("testfile", 1), newBlockID) 141 | } 142 | }) 143 | t.Run("Length", func(t *testing.T) { 144 | t.Parallel() 145 | fm, err := NewManager(t.TempDir(), 128) 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | // A file that has never been written has length 0. 150 | length, err := fm.Length("testfile") 151 | if err != nil { 152 | t.Fatal(err) 153 | } 154 | if length != 0 { 155 | t.Errorf("expected 0, got %d", length) 156 | } 157 | 158 | blockID := NewBlockID("testfile", 0) 159 | writeP
:= NewPage(128) 160 | writeP.SetStr(0, "abcd") 161 | 162 | err = fm.Write(blockID, writeP) 163 | if err != nil { 164 | t.Fatal(err) 165 | } 166 | // After writing block 0 the file should be exactly one block long. 167 | length, err = fm.Length("testfile") 168 | if err != nil { 169 | t.Fatal(err) 170 | } 171 | if length != 1 { 172 | t.Errorf("expected 1, got %d", length) 173 | } 174 | }) 175 | } 176 | -------------------------------------------------------------------------------- /internal/index/btree_test.go: -------------------------------------------------------------------------------- 1 | package index_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/abekoh/simple-db/internal/index" 10 | "github.com/abekoh/simple-db/internal/record" 11 | "github.com/abekoh/simple-db/internal/record/schema" 12 | "github.com/abekoh/simple-db/internal/simpledb" 13 | ) 14 | // TestNewBTreeIndex_OneIndex inserts 520 distinct strings into a table and a B-tree index, logging a tree dump after each insert, then checks that every key leads back to its own row. 15 | func TestNewBTreeIndex_OneIndex(t *testing.T) { 16 | ctx := context.Background() 17 | db, err := simpledb.New(ctx, t.TempDir()) 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | tx, err := db.NewTx(ctx) 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | // Table "mytable" with one Varchar(10) field A, indexed by the B-tree "myindex". 26 | sche := schema.NewSchema() 27 | field := schema.NewField(schema.Varchar, 10) 28 | sche.AddField("A", field) 29 | recordLayout := record.NewLayoutSchema(sche) 30 | ts, err := record.NewTableScan(tx, "mytable", recordLayout) 31 | if err != nil { 32 | t.Fatal(err) 33 | } 34 | idxLayout := index.NewIndexLayout(field) 35 | idx, err := index.NewBTreeIndex(tx, "myindex", idxLayout) 36 | if err != nil { 37 | t.Fatal(err) 38 | } 39 | // 10 rounds of a..z then A..Z, each key five copies of the letter plus a running counter: 520 distinct keys. 40 | vals := make([]schema.Constant, 0) 41 | cnt := 0 42 | for range 10 { 43 | for c := 'a'; c <= 'z'; c++ { 44 | vals = append(vals, schema.ConstantStr(fmt.Sprintf("%s%d", strings.Repeat(string(c), 5), cnt))) 45 | cnt++ 46 | } 47 | for c := 'A'; c <= 'Z'; c++ { 48 | vals = append(vals, schema.ConstantStr(fmt.Sprintf("%s%d", strings.Repeat(string(c), 5), cnt))) 49 | cnt++ 50 | } 51 | } 52 | 53 | for _, val := range vals { 54 | if err := ts.Insert(); err
!= nil { 55 | t.Fatal(err) 56 | } 57 | if err := ts.SetVal("A", val); err != nil { 58 | t.Fatal(err) 59 | } 60 | if err := idx.Insert(val, ts.RID()); err != nil { 61 | t.Fatal(err) 62 | } 63 | idx := idx.(*index.BTreeIndex) // loop-local shadow: Dump is not part of the Index interface 64 | // Log a dump of the whole tree after every insert. 65 | d, err := idx.Dump() 66 | if err != nil { 67 | t.Fatal(err) 68 | } 69 | t.Logf("Val: %v, Dump: %v\n", val, d) 70 | } 71 | // Every key must be found through the index and lead back to a row holding that key. 72 | for _, val := range vals { 73 | if err := idx.BeforeFirst(val); err != nil { 74 | t.Fatal(err) 75 | } 76 | ok, err := idx.Next() 77 | if err != nil { 78 | t.Fatal(err) 79 | } 80 | if !ok { 81 | t.Errorf("no record found for %v", val) 82 | continue 83 | } 84 | 85 | gotRID, err := idx.DataRID() 86 | if err != nil { 87 | t.Fatal(err) 88 | } 89 | if err := ts.MoveToRID(gotRID); err != nil { 90 | t.Fatal(err) 91 | } 92 | gotVal, err := ts.Val("A") 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | if !val.Equals(gotVal) { 97 | t.Errorf("expect %v, got %v", val, gotVal) 98 | } 99 | } 100 | 101 | if err := idx.Close(); err != nil { 102 | t.Fatal(err) 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /internal/index/index.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/abekoh/simple-db/internal/record" 7 | "github.com/abekoh/simple-db/internal/record/schema" 8 | "github.com/abekoh/simple-db/internal/transaction" 9 | ) 10 | // Index is the access-method interface implemented by HashIndex and BTreeIndex; callers position it with BeforeFirst and walk matches with Next and DataRID. 11 | type ( 12 | Index interface { 13 | BeforeFirst(searchKey schema.Constant) error 14 | Next() (bool, error) 15 | DataRID() (schema.RID, error) 16 | Insert(dataVal schema.Constant, dataRID schema.RID) error 17 | Delete(dataVal schema.Constant, dataRID schema.RID) error 18 | Close() error 19 | } 20 | Initializer = func(tx *transaction.Transaction, idxName string, layout *record.Layout) (Index, error) // builds the index named idxName whose records use layout 21 | SearchCost = func(numBlocks, rpb int) int // estimated block accesses for one search (rpb presumably records per block) 22 | Config struct { 23 | Initializer Initializer 24 | SearchCost SearchCost 25 | } 26 | ) 27 | 28 |
var ( 29 | ConfigHash = &Config{ 30 | Initializer: NewHashIndex, 31 | SearchCost: HashSearchCost, 32 | } 33 | ConfigBTree = &Config{ 34 | Initializer: NewBTreeIndex, 35 | SearchCost: BTreeSearchCost, 36 | } 37 | ) 38 | // Field names of an index record: the indexed data record's RID (block number and slot id) and the indexed value. 39 | const ( 40 | blockFld = "block" 41 | idFld = "id" 42 | dataFld = "dataval" 43 | ) 44 | // NewIndexLayout returns the layout of an index record: int32 block and id fields followed by a dataval field of field's type. 45 | func NewIndexLayout(field schema.Field) *record.Layout { 46 | sche := schema.NewSchema() 47 | sche.AddInt32Field(blockFld) 48 | sche.AddInt32Field(idFld) 49 | sche.AddField(dataFld, field) 50 | return record.NewLayoutSchema(sche) 51 | } 52 | // numBuckets is the number of bucket tables a HashIndex spreads its entries over. 53 | const numBuckets = 100 54 | // HashIndex is a static hash index: each bucket is a separate table named idxName followed by the bucket number. 55 | type HashIndex struct { 56 | tx *transaction.Transaction 57 | idxName string 58 | layout *record.Layout 59 | searchKey schema.Constant 60 | tableScan *record.TableScan // scan over the current bucket table; nil until the first BeforeFirst 61 | } 62 | // NewHashIndex returns an unpositioned HashIndex; no bucket table is opened until BeforeFirst. 63 | func NewHashIndex(tx *transaction.Transaction, idxName string, layout *record.Layout) (Index, error) { 64 | return &HashIndex{tx: tx, idxName: idxName, layout: layout}, nil 65 | } 66 | // HashSearchCost estimates one lookup as numBlocks/numBuckets block reads; rpb is unused. 67 | func HashSearchCost(numBlocks, rpb int) int { 68 | return numBlocks / numBuckets 69 | } 70 | 71 | var _ Index = (*HashIndex)(nil) 72 | // BeforeFirst closes any open bucket scan and opens the bucket table that searchKey hashes to. 73 | func (h *HashIndex) BeforeFirst(searchKey schema.Constant) error { 74 | if err := h.Close(); err != nil { 75 | return fmt.Errorf("index.Close error: %w", err) 76 | } 77 | h.searchKey = searchKey 78 | bucket := searchKey.HashCode() % numBuckets // NOTE(review): if HashCode can be negative, bucket can be too; lookups stay consistent but up to 2*numBuckets-1 bucket tables may exist 79 | tableName := fmt.Sprintf("%s%d", h.idxName, bucket) 80 | ts, err := record.NewTableScan(h.tx, tableName, h.layout) 81 | if err != nil { 82 | return fmt.Errorf("record.NewTableScan error: %w", err) 83 | } 84 | h.tableScan = ts 85 | return nil 86 | } 87 | // Next advances to the next entry in the current bucket whose dataval equals the search key; BeforeFirst must have been called first. 88 | func (h *HashIndex) Next() (bool, error) { 89 | for { 90 | ok, err := h.tableScan.Next() 91 | if err != nil { 92 | return false, fmt.Errorf("tableScan.Next error: %w", err) 93 | } 94 | if !ok { 95 | return false, nil 96 | } 97 | val, err := h.tableScan.Val(dataFld) 98 | if err != nil { 99 | return false, fmt.Errorf("tableScan.Val error: %w", err) 100 | } 101 | if
h.searchKey.Equals(val) { 102 | return true, nil 103 | } 104 | } 105 | } 106 | // DataRID returns the RID of the data record that the current index entry points to. 107 | func (h *HashIndex) DataRID() (schema.RID, error) { 108 | blockNum, err := h.tableScan.Int32(blockFld) 109 | if err != nil { 110 | return schema.RID{}, fmt.Errorf("tableScan.Int32 error: %w", err) 111 | } 112 | id, err := h.tableScan.Int32(idFld) 113 | if err != nil { 114 | return schema.RID{}, fmt.Errorf("tableScan.Int32 error: %w", err) 115 | } 116 | return schema.NewRID(blockNum, id), nil 117 | } 118 | // Insert adds an entry mapping dataVal to dataRID to the bucket that dataVal hashes to. 119 | func (h *HashIndex) Insert(dataVal schema.Constant, dataRID schema.RID) error { 120 | if err := h.BeforeFirst(dataVal); err != nil { 121 | return fmt.Errorf("index.BeforeFirst error: %w", err) 122 | } 123 | if err := h.tableScan.Insert(); err != nil { 124 | return fmt.Errorf("tableScan.Insert error: %w", err) 125 | } 126 | if err := h.tableScan.SetInt32(blockFld, dataRID.BlockNum()); err != nil { 127 | return fmt.Errorf("tableScan.SetInt32 error: %w", err) 128 | } 129 | if err := h.tableScan.SetInt32(idFld, dataRID.Slot()); err != nil { 130 | return fmt.Errorf("tableScan.SetInt32 error: %w", err) 131 | } 132 | if err := h.tableScan.SetVal(dataFld, dataVal); err != nil { 133 | return fmt.Errorf("tableScan.SetVal error: %w", err) 134 | } 135 | return nil 136 | } 137 | // Delete removes the entry for (dataVal, dataRID) from dataVal's bucket; a missing entry is not an error. 138 | func (h *HashIndex) Delete(dataVal schema.Constant, dataRID schema.RID) error { 139 | if err := h.BeforeFirst(dataVal); err != nil { 140 | return fmt.Errorf("index.BeforeFirst error: %w", err) 141 | } 142 | for { 143 | ok, err := h.Next() 144 | if err != nil { 145 | return fmt.Errorf("index.Next error: %w", err) 146 | } 147 | if !ok { 148 | return nil 149 | } 150 | rid, err := h.DataRID() 151 | if err != nil { 152 | return fmt.Errorf("index.DataRID error: %w", err) 153 | } 154 | if rid.Equals(dataRID) { 155 | if err := h.tableScan.Delete(); err != nil { 156 | return fmt.Errorf("tableScan.Delete error: %w", err) 157 | } 158 | return nil 159 | } 160 | } 161 | } 162 | // Close closes the current bucket scan, if one is open. 163 | func (h *HashIndex) Close() error { 164
| if h.tableScan != nil { 165 | if err := h.tableScan.Close(); err != nil { 166 | return fmt.Errorf("tableScan.Close error: %w", err) 167 | } 168 | // Forget the closed scan so a repeated Close (or a BeforeFirst after Close) does not close it twice. 169 | h.tableScan = nil 170 | } 171 | return nil 172 | } 173 | -------------------------------------------------------------------------------- /internal/index/scan.go: -------------------------------------------------------------------------------- 1 | package index 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/abekoh/simple-db/internal/query" 7 | "github.com/abekoh/simple-db/internal/record" 8 | "github.com/abekoh/simple-db/internal/record/schema" 9 | ) 10 | // SelectScan yields the records of tableScan whose indexed field equals val, located through idx. 11 | type SelectScan struct { 12 | tableScan *record.TableScan 13 | idx Index 14 | val schema.Constant 15 | } 16 | // NewSelectScan builds a SelectScan and positions it before the first match. 17 | func NewSelectScan(tableScan *record.TableScan, idx Index, val schema.Constant) (*SelectScan, error) { 18 | s := &SelectScan{tableScan: tableScan, idx: idx, val: val} 19 | if err := s.BeforeFirst(); err != nil { 20 | return nil, fmt.Errorf("BeforeFirst error: %w", err) 21 | } 22 | return s, nil 23 | } 24 | 25 | var _ query.Scan = (*SelectScan)(nil) 26 | // Val returns fieldName from the current table record. 27 | func (s SelectScan) Val(fieldName schema.FieldName) (schema.Constant, error) { 28 | val, err := s.tableScan.Val(fieldName) 29 | if err != nil { 30 | return nil, fmt.Errorf("tableScan.Val error: %w", err) 31 | } 32 | return val, nil 33 | } 34 | // BeforeFirst repositions the index before the first entry for val. 35 | func (s SelectScan) BeforeFirst() error { 36 | if err := s.idx.BeforeFirst(s.val); err != nil { 37 | return fmt.Errorf("index.BeforeFirst error: %w", err) 38 | } 39 | return nil 40 | } 41 | // Next advances the index and, when it finds an entry, moves the table scan to the referenced record. 42 | func (s SelectScan) Next() (bool, error) { 43 | ok, err := s.idx.Next() 44 | if err != nil { 45 | return false, fmt.Errorf("index.Next error: %w", err) 46 | } 47 | if ok { 48 | rid, err := s.idx.DataRID() 49 | if err != nil { 50 | return false, fmt.Errorf("index.DataRID error: %w", err) 51 | } 52 | if err := s.tableScan.MoveToRID(rid); err != nil { 53 | return false, fmt.Errorf("tableScan.MoveToRID error: %w", err) 54 | } 55 | } 56 | return ok, nil 57 | } 58 | // Int32 returns the int32 field fieldName of the current table record. 59 | func (s SelectScan) 
Int32(fieldName schema.FieldName) (int32, error) { 60 | val, err := s.tableScan.Int32(fieldName) 61 | if err != nil { 62 | return 0, fmt.Errorf("tableScan.Int32 error: %w", err) 63 | } 64 | return val, nil 65 | } 66 | // Str returns the string field fieldName of the current table record. 67 | func (s SelectScan) Str(fieldName schema.FieldName) (string, error) { 68 | val, err := s.tableScan.Str(fieldName) 69 | if err != nil { 70 | return "", fmt.Errorf("tableScan.Str error: %w", err) 71 | } 72 | return val, nil 73 | } 74 | // HasField reports whether the underlying table scan has fieldName. 75 | func (s SelectScan) HasField(fieldName schema.FieldName) bool { 76 | return s.tableScan.HasField(fieldName) 77 | } 78 | // Close closes the index and then the table scan. 79 | func (s SelectScan) Close() error { 80 | if err := s.idx.Close(); err != nil { 81 | return fmt.Errorf("index.Close error: %w", err) // NOTE(review): returning here leaves tableScan open 82 | } 83 | if err := s.tableScan.Close(); err != nil { 84 | return fmt.Errorf("tableScan.Close error: %w", err) 85 | } 86 | return nil 87 | } 88 | // JoinScan pairs each lhs record with the rhs records whose index entry matches the lhs joinField value. 89 | type JoinScan struct { 90 | lhs query.Scan 91 | rhs *record.TableScan 92 | idx Index 93 | joinField schema.FieldName 94 | } 95 | // NewJoinScan builds a JoinScan and positions it on the first lhs record. 96 | func NewJoinScan(lhs query.Scan, rhs *record.TableScan, idx Index, joinField schema.FieldName) (*JoinScan, error) { 97 | js := &JoinScan{lhs: lhs, idx: idx, joinField: joinField, rhs: rhs} 98 | if err := js.BeforeFirst(); err != nil { 99 | return nil, fmt.Errorf("BeforeFirst error: %w", err) 100 | } 101 | return js, nil 102 | } 103 | 104 | var _ query.Scan = (*JoinScan)(nil) 105 | // Val returns fieldName from lhs when lhs has it, otherwise from the current rhs record. 106 | func (j JoinScan) Val(fieldName schema.FieldName) (schema.Constant, error) { 107 | if j.lhs.HasField(fieldName) { 108 | if val, err := j.lhs.Val(fieldName); err != nil { 109 | return nil, fmt.Errorf("lhs.Val error: %w", err) 110 | } else { 111 | return val, nil 112 | } 113 | } else { 114 | if val, err := j.rhs.Val(fieldName); err != nil { 115 | return nil, fmt.Errorf("rhs.Val error: %w", err) 116 | } else { 117 | return val, nil 118 | } 119 | } 120 | } 121 | // BeforeFirst rewinds lhs, moves to its first record and positions idx on that record's matches. 122 | func (j JoinScan) BeforeFirst() error { 123 | if err := j.lhs.BeforeFirst(); err != nil { 124 | return fmt.Errorf("lhs.BeforeFirst
error: %w", err) 125 | } 126 | if _, err := j.lhs.Next(); err != nil { // NOTE(review): the bool is ignored, so with an empty lhs resetIndex reads joinField while lhs has no current record 127 | return fmt.Errorf("lhs.Next error: %w", err) 128 | } 129 | if err := j.resetIndex(); err != nil { 130 | return fmt.Errorf("resetIndex error: %w", err) 131 | } 132 | return nil 133 | } 134 | // Next returns the next rhs match for the current lhs record, advancing lhs and re-seeking idx whenever the current matches run out. 135 | func (j JoinScan) Next() (bool, error) { 136 | for { 137 | idxOk, err := j.idx.Next() 138 | if err != nil { 139 | return false, fmt.Errorf("index.Next error: %w", err) 140 | } 141 | if idxOk { 142 | rid, err := j.idx.DataRID() 143 | if err != nil { 144 | return false, fmt.Errorf("index.DataRID error: %w", err) 145 | } 146 | if err := j.rhs.MoveToRID(rid); err != nil { 147 | return false, fmt.Errorf("rhs.MoveToRID error: %w", err) 148 | } 149 | return true, nil 150 | } 151 | lhsOk, err := j.lhs.Next() 152 | if err != nil { 153 | return false, fmt.Errorf("lhs.Next error: %w", err) 154 | } 155 | if !lhsOk { 156 | return false, nil 157 | } 158 | if err := j.resetIndex(); err != nil { 159 | return false, fmt.Errorf("resetIndex error: %w", err) 160 | } 161 | } 162 | } 163 | // Int32 reads fieldName from lhs when lhs has it, otherwise from the current rhs record. 164 | func (j JoinScan) Int32(fieldName schema.FieldName) (int32, error) { 165 | if j.lhs.HasField(fieldName) { 166 | if val, err := j.lhs.Int32(fieldName); err != nil { 167 | return 0, fmt.Errorf("lhs.Int32 error: %w", err) 168 | } else { 169 | return val, nil 170 | } 171 | } else { 172 | if val, err := j.rhs.Int32(fieldName); err != nil { 173 | return 0, fmt.Errorf("rhs.Int32 error: %w", err) 174 | } else { 175 | return val, nil 176 | } 177 | } 178 | } 179 | // Str reads fieldName from lhs when lhs has it, otherwise from the current rhs record. 180 | func (j JoinScan) Str(fieldName schema.FieldName) (string, error) { 181 | if j.lhs.HasField(fieldName) { 182 | if val, err := j.lhs.Str(fieldName); err != nil { 183 | return "", fmt.Errorf("lhs.Str error: %w", err) 184 | } else { 185 | return val, nil 186 | } 187 | } else { 188 | if val, err := j.rhs.Str(fieldName); err != nil { 189 | return "", fmt.Errorf("rhs.Str error: %w", err) 190 | } else { 191 | return val, nil 192 | } 193 | } 194 | } 195 | // HasField reports whether either side of the join has fieldName. 196 | func (j JoinScan) 
HasField(fieldName schema.FieldName) bool { 197 | return j.lhs.HasField(fieldName) || j.rhs.HasField(fieldName) 198 | } 199 | // Close closes lhs, the index and rhs, in that order, stopping at the first error. 200 | func (j JoinScan) Close() error { 201 | if err := j.lhs.Close(); err != nil { 202 | return fmt.Errorf("lhs.Close error: %w", err) // NOTE(review): returning here leaves idx and rhs open 203 | } 204 | if err := j.idx.Close(); err != nil { 205 | return fmt.Errorf("index.Close error: %w", err) 206 | } 207 | if err := j.rhs.Close(); err != nil { 208 | return fmt.Errorf("rhs.Close error: %w", err) 209 | } 210 | return nil 211 | } 212 | // resetIndex positions idx before the entries matching the current lhs record's joinField value. 213 | func (j JoinScan) resetIndex() error { 214 | searchKey, err := j.lhs.Val(j.joinField) 215 | if err != nil { 216 | return fmt.Errorf("lhs.Val error: %w", err) 217 | } 218 | if err := j.idx.BeforeFirst(searchKey); err != nil { 219 | return fmt.Errorf("index.BeforeFirst error: %w", err) 220 | } 221 | return nil 222 | } 223 | -------------------------------------------------------------------------------- /internal/index/scan_test.go: -------------------------------------------------------------------------------- 1 | package index_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/abekoh/simple-db/internal/index" 10 | "github.com/abekoh/simple-db/internal/query" 11 | "github.com/abekoh/simple-db/internal/record" 12 | "github.com/abekoh/simple-db/internal/record/schema" 13 | "github.com/abekoh/simple-db/internal/simpledb" 14 | ) 15 | // TestIndexScan runs a select-scan and a join-scan scenario against both the hash and the B-tree index. 16 | func TestIndexScan(t *testing.T) { 17 | type test struct { 18 | name string 19 | cfg *index.Config 20 | } 21 | for _, tt := range []test{ 22 | {name: "Hash", cfg: index.ConfigHash}, 23 | {name: "BTree", cfg: index.ConfigBTree}, 24 | } { 25 | cfg := tt.cfg 26 | t.Run(tt.name+": TableScan -> IndexSelectScan -> ProjectScan", func(t *testing.T) { // prefixed with tt.name so Hash and BTree failures are distinguishable 27 | ctx := context.Background() 28 | db, err := simpledb.New(ctx, t.TempDir()) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | tx, err := db.NewTx(ctx) 33 | if err != nil { 34 | t.Fatal(err) 35 | } 36 | 37 | sche := schema.NewSchema() 38 |
sche.AddInt32Field("A") 39 | fieldB := schema.NewField(schema.Varchar, 9) 40 | sche.AddField("B", fieldB) 41 | layout := record.NewLayoutSchema(sche) 42 | scan1, err := record.NewTableScan(tx, "T", layout) 43 | if err != nil { 44 | t.Fatal(err) 45 | } 46 | 47 | idxLayout := index.NewIndexLayout(fieldB) 48 | idx1, err := cfg.Initializer(tx, "I", idxLayout) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | if err := scan1.BeforeFirst(); err != nil { 54 | t.Fatal(err) 55 | } 56 | for i := 0; i < 200; i++ { 57 | if err := scan1.Insert(); err != nil { 58 | t.Fatal(err) 59 | } 60 | rec := schema.ConstantStr(fmt.Sprintf("rec%d", i)) 61 | if err := scan1.SetVal("B", rec); err != nil { 62 | t.Fatal(err) 63 | } 64 | if err := idx1.Insert(rec, scan1.RID()); err != nil { 65 | t.Fatal(err) 66 | } 67 | } 68 | if err := scan1.Close(); err != nil { 69 | t.Fatal(err) 70 | } 71 | if err := idx1.Close(); err != nil { 72 | t.Fatal(err) 73 | } 74 | // Reopen the table and the index, then look up "rec100" through a SelectScan. 75 | scan2, err := record.NewTableScan(tx, "T", layout) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | idx2, err := cfg.Initializer(tx, "I", idxLayout) 80 | if err != nil { 81 | t.Fatal(err) 82 | } 83 | scan3, err := index.NewSelectScan(scan2, idx2, schema.ConstantStr("rec100")) 84 | if err != nil { 85 | t.Fatal(err) 86 | } 87 | scan4 := query.NewProjectScan(scan3, "B") 88 | 89 | ok, err := scan4.Next() 90 | if err != nil { 91 | t.Fatal(err) 92 | } 93 | if !ok { 94 | t.Fatal("no records") 95 | } 96 | b, err := scan4.Str("B") 97 | if err != nil { 98 | t.Fatal(err) 99 | } 100 | if b != "rec100" { 101 | t.Fatalf("unexpected value: %s", b) 102 | } 103 | if err := tx.Commit(); err != nil { 104 | t.Fatal(err) 105 | } 106 | }) 107 | t.Run(tt.name+": TableScan*2 -> IndexProductScan -> ProjectScan", func(t *testing.T) { 108 | ctx := context.Background() 109 | db, err := simpledb.New(ctx, t.TempDir()) 110 | if err != nil { 111 | t.Fatal(err) 112 | } 113 | tx, err := db.NewTx(ctx) 114 | if err != nil { 115 | t.Fatal(err) 116 | } 117 | 118 | lhsSche :=
schema.NewSchema() 119 | lhsSche.AddStrField("A", 9) 120 | fieldB := schema.NewField(schema.Varchar, 9) 121 | lhsSche.AddField("B", fieldB) 122 | lhsLayout := record.NewLayoutSchema(lhsSche) 123 | lhsTableS1, err := record.NewTableScan(tx, "T1", lhsLayout) 124 | if err != nil { 125 | t.Fatal(err) 126 | } 127 | if err := lhsTableS1.BeforeFirst(); err != nil { 128 | t.Fatal(err) 129 | } 130 | n := 5 131 | for i := 0; i < n; i++ { 132 | if err := lhsTableS1.Insert(); err != nil { 133 | t.Fatal(err) 134 | } 135 | if err := lhsTableS1.SetStr("A", fmt.Sprintf("aaa%d", i)); err != nil { 136 | t.Fatal(err) 137 | } 138 | if err := lhsTableS1.SetStr("B", fmt.Sprintf("bbb%d", i)); err != nil { 139 | t.Fatal(err) 140 | } 141 | } 142 | if err := lhsTableS1.Close(); err != nil { 143 | t.Fatal(err) 144 | } 145 | // rhs gets n+5 rows (bbb0..bbb9) indexed on B; only the first n of them match an lhs row. 146 | rhsSche := schema.NewSchema() 147 | rhsSche.AddField("B", fieldB) 148 | rhsSche.AddStrField("C", 9) 149 | rhsLayout := record.NewLayoutSchema(rhsSche) 150 | rhsTableS1, err := record.NewTableScan(tx, "T2", rhsLayout) 151 | if err != nil { 152 | t.Fatal(err) 153 | } 154 | idxLayout := index.NewIndexLayout(fieldB) 155 | rhsIdx1, err := cfg.Initializer(tx, "I", idxLayout) 156 | if err != nil { 157 | t.Fatal(err) 158 | } 159 | if err := rhsTableS1.BeforeFirst(); err != nil { 160 | t.Fatal(err) 161 | } 162 | for i := 0; i < n+5; i++ { 163 | if err := rhsTableS1.Insert(); err != nil { 164 | t.Fatal(err) 165 | } 166 | rec := schema.ConstantStr(fmt.Sprintf("bbb%d", i)) 167 | if err := rhsTableS1.SetVal("B", rec); err != nil { 168 | t.Fatal(err) 169 | } 170 | if err := rhsIdx1.Insert(rec, rhsTableS1.RID()); err != nil { 171 | t.Fatal(err) 172 | } 173 | if err := rhsTableS1.SetStr("C", fmt.Sprintf("ccc%d", i)); err != nil { 174 | t.Fatal(err) 175 | } 176 | } 177 | if err := rhsTableS1.Close(); err != nil { 178 | t.Fatal(err) 179 | } 180 | if err := rhsIdx1.Close(); err != nil { 181 | t.Fatal(err) 182 | } 183 | // Reopen both tables and the index, then join them on B. 184 | lhsTableS2, err := record.NewTableScan(tx, "T1", lhsLayout)
185 | if err != nil { 186 | t.Fatal(err) 187 | } 188 | rhsTableS2, err := record.NewTableScan(tx, "T2", rhsLayout) 189 | if err != nil { 190 | t.Fatal(err) 191 | } 192 | idx2, err := cfg.Initializer(tx, "I", idxLayout) 193 | if err != nil { 194 | t.Fatal(err) 195 | } 196 | joinS, err := index.NewJoinScan(lhsTableS2, rhsTableS2, idx2, "B") 197 | if err != nil { 198 | t.Fatal(err) 199 | } 200 | prjS := query.NewProjectScan(joinS, "A", "B", "C") 201 | got := make([]string, 0, n) 202 | for { 203 | ok, err := prjS.Next() 204 | if err != nil { 205 | t.Fatal(err) 206 | } 207 | if !ok { 208 | break 209 | } 210 | a, err := prjS.Str("A") 211 | if err != nil { 212 | t.Fatal(err) 213 | } 214 | b, err := prjS.Str("B") 215 | if err != nil { 216 | t.Fatal(err) 217 | } 218 | c, err := prjS.Str("C") 219 | if err != nil { 220 | t.Fatal(err) 221 | } 222 | got = append(got, fmt.Sprintf("%s, %s, %s", a, b, c)) 223 | } 224 | if len(got) != n { 225 | t.Errorf("got %d, want %d", len(got), n) 226 | } 227 | expected := `aaa0, bbb0, ccc0 228 | aaa1, bbb1, ccc1 229 | aaa2, bbb2, ccc2 230 | aaa3, bbb3, ccc3 231 | aaa4, bbb4, ccc4` 232 | if strings.Join(got, "\n") != expected { 233 | t.Errorf("got %s, want %s", strings.Join(got, "\n"), expected) 234 | } 235 | if err := prjS.Close(); err != nil { 236 | t.Fatal(err) 237 | } 238 | if err := tx.Commit(); err != nil { 239 | t.Fatal(err) 240 | } 241 | }) 242 | } 243 | } 244 | -------------------------------------------------------------------------------- /internal/log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | 7 | "github.com/abekoh/simple-db/internal/file" 8 | ) 9 | // SequenceNumber is a log sequence number (LSN); each Manager instance numbers the records it appends 1, 2, 3, ... 10 | type SequenceNumber int32 11 | // Manager appends records to a single log file through a file.Manager, filling each block from its end toward its start. 12 | type Manager struct { 13 | fm *file.Manager 14 | filename string 15 | page *file.Page 16 | currentBlockID file.BlockID 17 | latestLSN SequenceNumber // LSN returned by the most recent Append 18 | lastSavedLSN SequenceNumber // value of latestLSN at the last flush 19 | appendMu sync.Mutex // held by Append while it mutates page and the LSN counters 20 | } 21 | // NewManager opens the log file filename, appending its first block if the file is empty or otherwise loading its last block into the page. 22 | func NewManager(fm
*file.Manager, filename string) (*Manager, error) { 23 | p := file.NewPage(fm.BlockSize()) 24 | m := &Manager{ 25 | fm: fm, 26 | filename: filename, 27 | page: p, 28 | } 29 | logSize, err := fm.Length(filename) 30 | if err != nil { 31 | return nil, fmt.Errorf("could not get length: %w", err) 32 | } 33 | if logSize == 0 { 34 | currentBlockID, err := m.appendNewBlock() 35 | if err != nil { 36 | return nil, fmt.Errorf("could not append new block: %w", err) 37 | } 38 | m.currentBlockID = currentBlockID 39 | } else { 40 | m.currentBlockID = file.NewBlockID(filename, logSize-1) 41 | err = fm.Read(m.currentBlockID, m.page) 42 | if err != nil { 43 | return nil, fmt.Errorf("could not read: %w", err) 44 | } 45 | } 46 | return m, nil 47 | } 48 | 49 | func (m *Manager) Flush(lsn SequenceNumber) error { 50 | if lsn >= m.latestLSN { 51 | return m.flush() 52 | } 53 | return nil 54 | } 55 | 56 | func (m *Manager) Append(rec []byte) (lsn SequenceNumber, err error) { 57 | m.appendMu.Lock() 58 | defer m.appendMu.Unlock() 59 | boundary := m.page.Int32(0) 60 | recNeedSize := file.NeedSize(rec) 61 | if boundary-recNeedSize < 4 { 62 | if err := m.flush(); err != nil { 63 | return 0, fmt.Errorf("could not flush: %w", err) 64 | } 65 | m.currentBlockID, err = m.appendNewBlock() 66 | if err != nil { 67 | return 0, fmt.Errorf("could not append new block: %w", err) 68 | } 69 | boundary = m.page.Int32(0) 70 | } 71 | recPos := boundary - recNeedSize 72 | m.page.SetBytes(recPos, rec) 73 | m.page.SetInt32(0, recPos) 74 | m.latestLSN += 1 75 | return m.latestLSN, nil 76 | } 77 | 78 | func (m *Manager) Iterator() func(func([]byte) bool) { 79 | m.flush() 80 | p := file.NewPage(m.fm.BlockSize()) 81 | blockID := m.currentBlockID 82 | var currentPos int32 83 | moveToBlock := func(blkID file.BlockID) bool { 84 | err := m.fm.Read(blkID, p) 85 | if err != nil { 86 | return false 87 | } 88 | currentPos = p.Int32(0) // boundary 89 | return true 90 | } 91 | moveToBlock(blockID) 92 | return func(yield 
func([]byte) bool) { 93 | for { 94 | if currentPos >= m.fm.BlockSize() && blockID.Num() <= 0 { 95 | return 96 | } 97 | if currentPos == m.fm.BlockSize() { 98 | blockID = file.NewBlockID(blockID.Filename(), blockID.Num()-1) 99 | if !moveToBlock(blockID) { 100 | return 101 | } 102 | } 103 | rec := p.Bytes(currentPos) 104 | currentPos += 4 + int32(len(rec)) 105 | if !yield(rec) { 106 | return 107 | } 108 | } 109 | } 110 | } 111 | 112 | func (m *Manager) appendNewBlock() (file.BlockID, error) { 113 | blockID, err := m.fm.Append(m.filename) 114 | if err != nil { 115 | return file.BlockID{}, fmt.Errorf("could not append: %w", err) 116 | } 117 | m.page.SetInt32(0, m.fm.BlockSize()) 118 | if err := m.fm.Write(blockID, m.page); err != nil { 119 | return file.BlockID{}, fmt.Errorf("could not write: %w", err) 120 | } 121 | return blockID, nil 122 | } 123 | 124 | func (m *Manager) flush() error { 125 | if err := m.fm.Write(m.currentBlockID, m.page); err != nil { 126 | return fmt.Errorf("could not write: %w", err) 127 | } 128 | m.lastSavedLSN = m.latestLSN 129 | return nil 130 | } 131 | -------------------------------------------------------------------------------- /internal/log/log_test.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/abekoh/simple-db/internal/file" 9 | ) 10 | 11 | func TestManager(t *testing.T) { 12 | t.Parallel() 13 | t.Run("Append", func(t *testing.T) { 14 | t.Parallel() 15 | fm, err := file.NewManager(t.TempDir(), 128) 16 | if err != nil { 17 | t.Fatal(err) 18 | } 19 | lm, err := NewManager(fm, "logfile") 20 | if err != nil { 21 | t.Fatal(err) 22 | } 23 | lsn, err := lm.Append([]byte("abcd")) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | if lsn != 1 { 28 | t.Errorf("expected 0, got %d", lsn) 29 | } 30 | var logs []string 31 | for r := range lm.Iterator() { 32 | logs = append(logs, string(r)) 33 | } 34 | if 
!reflect.DeepEqual(logs, []string{"abcd"}) { 35 | t.Errorf("expected [abcd], got %v", logs) 36 | } 37 | }) 38 | t.Run("Append many", func(t *testing.T) { 39 | t.Parallel() 40 | fm, err := file.NewManager(t.TempDir(), 128) 41 | if err != nil { 42 | t.Fatal(err) 43 | } 44 | lm, err := NewManager(fm, "logfile") 45 | if err != nil { 46 | t.Fatal(err) 47 | } 48 | for i := 0; i < 100000; i++ { // far more records than one 128-byte block holds, so Iterator has to step back across many blocks 49 | _, err := lm.Append([]byte(fmt.Sprintf("%04d", i))) 50 | if err != nil { 51 | t.Fatal(err) 52 | } 53 | } 54 | count := 0 55 | for r := range lm.Iterator() { 56 | if string(r) != fmt.Sprintf("%04d", 100000-count-1) { 57 | t.Errorf("expected %04d, got %s", 100000-count-1, r) 58 | } 59 | count++ 60 | } 61 | if count != 100000 { 62 | t.Errorf("expected 100000, got %d", count) 63 | } 64 | }) 65 | } 66 | -------------------------------------------------------------------------------- /internal/metadata/metadata_test.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "slices" 8 | "testing" 9 | 10 | "github.com/abekoh/simple-db/internal/buffer" 11 | "github.com/abekoh/simple-db/internal/file" 12 | "github.com/abekoh/simple-db/internal/log" 13 | "github.com/abekoh/simple-db/internal/record" 14 | schema2 "github.com/abekoh/simple-db/internal/record/schema" 15 | "github.com/abekoh/simple-db/internal/transaction" 16 | ) 17 | // TestMetadataManager exercises the table, stat, view and index managers through one metadata Manager and transaction. 18 | func TestMetadataManager(t *testing.T) { 19 | t.Parallel() 20 | 21 | fm, err := file.NewManager(t.TempDir(), 400) 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | lm, err := log.NewManager(fm, "logfile") 26 | if err != nil { 27 | t.Fatal(err) 28 | } 29 | ctx := context.Background() 30 | bm := buffer.NewManager(ctx, fm, lm, 8) 31 | 32 | tx, err := transaction.NewTransaction(ctx, bm, fm, lm) 33 | if err != nil { 34 | t.Fatal(err) 35 | } 36 | 37 | m, err := NewManager(true, tx, nil) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | 42 | schema :=
schema2.NewSchema() 43 | schema.AddInt32Field("A") 44 | schema.AddStrField("B", 9) 45 | if err := m.CreateTable("MyTable", schema, tx); err != nil { 46 | t.Fatal(err) 47 | } 48 | 49 | // TableManager 50 | layout, err := m.Layout("MyTable", tx) 51 | if err != nil { 52 | t.Fatal(err) 53 | } 54 | if layout.SlotSize() != 21 { // presumably a 4-byte slot flag + 4 bytes for A + (4-byte length + 9) for B; confirm against record.NewLayoutSchema 55 | t.Errorf("expected 21, got %d", layout.SlotSize()) 56 | } 57 | var fields []string 58 | for _, fieldName := range layout.Schema().FieldNames() { 59 | switch layout.Schema().Typ(fieldName) { 60 | case schema2.Integer32: 61 | fields = append(fields, fmt.Sprintf("%s: int", fieldName)) 62 | case schema2.Varchar: 63 | fields = append(fields, fmt.Sprintf("%s: varchar(%d)", fieldName, layout.Schema().Length(fieldName))) 64 | } 65 | } 66 | slices.Sort(fields) 67 | if !reflect.DeepEqual(fields, []string{"A: int", "B: varchar(9)"}) { 68 | t.Errorf("expected [A: int, B: varchar(9)], got %v", fields) 69 | } 70 | 71 | if err := tx.Commit(); err != nil { // report Commit's own error, not the stale err left over from m.Layout 72 | t.Fatal(err) 73 | } 74 | 75 | // StatManager 76 | scan, err := record.NewTableScan(tx, "MyTable", layout) // NOTE(review): tx is still used after Commit above; confirm Transaction supports that 77 | if err != nil { 78 | t.Fatal(err) 79 | } 80 | for i := 0; i < 50; i++ { 81 | if err := scan.Insert(); err != nil { 82 | t.Fatal(err) 83 | } 84 | if err := scan.SetInt32("A", int32(i)); err != nil { 85 | t.Fatal(err) 86 | } 87 | if err := scan.SetStr("B", fmt.Sprintf("rec%d", i)); err != nil { 88 | t.Fatal(err) 89 | } 90 | } 91 | statInfo, err := m.StatInfo("MyTable", layout, tx) 92 | if err != nil { 93 | t.Fatal(err) 94 | } 95 | if statInfo.BlocksAccessed() != 3 { 96 | t.Errorf("expected 3, got %d", statInfo.BlocksAccessed()) 97 | } 98 | if statInfo.RecordsOutput() != 50 { 99 | t.Errorf("expected 50, got %d", statInfo.RecordsOutput()) 100 | } 101 | if statInfo.DistinctValues("A") != 17 { // presumably the estimate 1 + RecordsOutput/3; confirm in StatInfo 102 | t.Errorf("expected 17, got %d", statInfo.DistinctValues("A")) 103 | } 104 | if statInfo.DistinctValues("B") != 17 { 105 | t.Errorf("expected 17, got %d", statInfo.DistinctValues("B")) 106 | } 107 | 108 | //
ViewManager 109 | viewDef := "SELECT B FROM MyTable WHERE A = 1" 110 | if err := m.CreateView("MyView", viewDef, tx); err != nil { 111 | t.Fatal(err) 112 | } 113 | gotViewDef, ok, err := m.ViewDef("MyView", tx) 114 | if err != nil { 115 | t.Fatal(err) 116 | } 117 | if !ok { 118 | t.Fatal("expected true, got false") 119 | } 120 | if gotViewDef != viewDef { 121 | t.Errorf("expected %s, got %s", "SELECT B FROM MyTable WHERE A = 1", gotViewDef) 122 | } 123 | 124 | // IndexManager 125 | if err := m.CreateIndex("IndexA", "MyTable", "A", tx); err != nil { 126 | t.Fatal(err) 127 | } 128 | if err := m.CreateIndex("IndexB", "MyTable", "B", tx); err != nil { 129 | t.Fatal(err) 130 | } 131 | indexMap, err := m.IndexInfo("MyTable", tx) 132 | if err != nil { 133 | t.Fatal(err) 134 | } 135 | if len(indexMap) != 2 { 136 | t.Errorf("expected 2, got %d", len(indexMap)) 137 | } 138 | // TODO: assert indexMap 139 | } 140 | -------------------------------------------------------------------------------- /internal/parse/lexer.go: -------------------------------------------------------------------------------- 1 | package parse 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | // tokenType classifies a lexer token; keyword types are spelled as the upper-cased keyword itself so lookupToken can map a word to its type directly. 8 | type tokenType string 9 | 10 | const ( 11 | identifier tokenType = "IDENTIFIER" 12 | 13 | selectTok tokenType = "SELECT" 14 | from tokenType = "FROM" 15 | and tokenType = "AND" 16 | where tokenType = "WHERE" 17 | insert tokenType = "INSERT" 18 | into tokenType = "INTO" 19 | values tokenType = "VALUES" 20 | deleteTok tokenType = "DELETE" 21 | update tokenType = "UPDATE" 22 | set tokenType = "SET" 23 | create tokenType = "CREATE" 24 | table tokenType = "TABLE" 25 | intTok tokenType = "INT" 26 | varchar tokenType = "VARCHAR" 27 | view tokenType = "VIEW" 28 | as tokenType = "AS" 29 | index tokenType = "INDEX" 30 | join tokenType = "JOIN" 31 | on tokenType = "ON" 32 | placeholder tokenType = "PLACEHOLDER" // not in keywords, so lookupToken never returns it 33 | begin tokenType = "BEGIN" 34 | commit tokenType = "COMMIT" 35 | rollback tokenType = "ROLLBACK" 36 | order
tokenType = "ORDER" 37 | group tokenType = "GROUP" 38 | by tokenType = "BY" 39 | asc tokenType = "ASC" 40 | desc tokenType = "DESC" 41 | explain tokenType = "EXPLAIN" 42 | 43 | number tokenType = "NUMBER" 44 | stringTok tokenType = "STRING" 45 | 46 | equal tokenType = "EQUAL" 47 | comma tokenType = "COMMA" 48 | lparen tokenType = "LPAREN" 49 | rparen tokenType = "RPAREN" 50 | 51 | illegal tokenType = "ILLEGAL" 52 | 53 | max tokenType = "MAX" 54 | min tokenType = "MIN" 55 | sum tokenType = "SUM" 56 | count tokenType = "COUNT" 57 | 58 | asterisk tokenType = "*" 59 | 60 | eof tokenType = "EOF" 61 | ) 62 | 63 | var keywords = map[tokenType]struct{}{ 64 | selectTok: {}, 65 | from: {}, 66 | and: {}, 67 | where: {}, 68 | insert: {}, 69 | into: {}, 70 | values: {}, 71 | deleteTok: {}, 72 | update: {}, 73 | set: {}, 74 | create: {}, 75 | table: {}, 76 | intTok: {}, 77 | varchar: {}, 78 | view: {}, 79 | as: {}, 80 | index: {}, 81 | join: {}, 82 | on: {}, 83 | begin: {}, 84 | commit: {}, 85 | rollback: {}, 86 | order: {}, 87 | group: {}, 88 | by: {}, 89 | asc: {}, 90 | desc: {}, 91 | max: {}, 92 | min: {}, 93 | sum: {}, 94 | count: {}, 95 | asterisk: {}, 96 | explain: {}, 97 | } 98 | 99 | func lookupToken(ident string) tokenType { 100 | ident = strings.ToUpper(ident) 101 | if _, ok := keywords[tokenType(ident)]; ok { 102 | return tokenType(ident) 103 | } 104 | return identifier 105 | } 106 | 107 | type token struct { 108 | typ tokenType 109 | literal string 110 | } 111 | 112 | type Lexer struct { 113 | s string 114 | cursor int 115 | readCursor int 116 | char byte 117 | } 118 | 119 | func NewLexer(s string) *Lexer { 120 | l := &Lexer{s: s} 121 | l.readChar() 122 | return l 123 | } 124 | 125 | func (l *Lexer) NextToken() token { 126 | for l.char == ' ' || l.char == '\t' || l.char == '\n' || l.char == '\r' { 127 | l.readChar() 128 | } 129 | 130 | switch l.char { 131 | case '*': 132 | l.readChar() 133 | return token{typ: asterisk, literal: "*"} 134 | case '=': 135 | l.readChar() 
136 | return token{typ: equal, literal: "="} 137 | case ',': 138 | l.readChar() 139 | return token{typ: comma, literal: ","} 140 | case '(': 141 | l.readChar() 142 | return token{typ: lparen, literal: "("} 143 | case ')': 144 | l.readChar() 145 | return token{typ: rparen, literal: ")"} 146 | case '\'': 147 | tok := token{typ: stringTok, literal: l.readString()} 148 | l.readChar() 149 | return tok 150 | case '$': 151 | l.readChar() 152 | return token{typ: placeholder, literal: fmt.Sprintf("$%s", l.readNumber())} 153 | case 0: 154 | return token{typ: eof, literal: ""} 155 | default: 156 | if isLetter(l.char) { 157 | ident := l.readIdentifier() 158 | typ := lookupToken(ident) 159 | return token{typ: typ, literal: ident} 160 | } else if isDigit(l.char) { 161 | return token{typ: number, literal: l.readNumber()} 162 | } else { 163 | return token{typ: illegal, literal: string(l.char)} 164 | } 165 | } 166 | } 167 | 168 | func (l *Lexer) TokenIterator() func(func(token) bool) { 169 | return func(yield func(token) bool) { 170 | for { 171 | tok := l.NextToken() 172 | if !yield(tok) { 173 | break 174 | } 175 | if tok.typ == eof { 176 | break 177 | } 178 | } 179 | } 180 | } 181 | 182 | func (l *Lexer) readChar() { 183 | if l.readCursor >= len(l.s) { 184 | l.char = 0 185 | } else { 186 | l.char = l.s[l.readCursor] 187 | } 188 | l.cursor = l.readCursor 189 | l.readCursor++ 190 | } 191 | 192 | func (l *Lexer) readIdentifier() string { 193 | start := l.cursor 194 | for isLetter(l.char) || isDigit(l.char) || l.char == '.' 
{ 195 | l.readChar() 196 | } 197 | return l.s[start:l.cursor] 198 | } 199 | 200 | func (l *Lexer) readString() string { 201 | start := l.cursor + 1 202 | for { 203 | l.readChar() 204 | if l.char == '\'' || l.char == 0 { 205 | break 206 | } 207 | } 208 | return l.s[start:l.cursor] 209 | } 210 | 211 | func (l *Lexer) readNumber() string { 212 | start := l.cursor 213 | for isDigit(l.char) { 214 | l.readChar() 215 | } 216 | return l.s[start:l.cursor] 217 | } 218 | 219 | func isLetter(char byte) bool { 220 | return 'a' <= char && char <= 'z' || 'A' <= char && char <= 'Z' || char == '_' 221 | } 222 | 223 | func isDigit(char byte) bool { 224 | return '0' <= char && char <= '9' 225 | } 226 | 227 | func (l *Lexer) Reset() { 228 | l.cursor = 0 229 | l.readCursor = 0 230 | var char byte 231 | l.char = char 232 | l.readChar() 233 | } 234 | -------------------------------------------------------------------------------- /internal/parse/lexer_test.go: -------------------------------------------------------------------------------- 1 | package parse 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestLexer(t *testing.T) { 9 | tests := []struct { 10 | input string 11 | want []token 12 | }{ 13 | { 14 | input: `SELECT a, b FROM mytable WHERE a = 1 AND b = 'foo'`, 15 | want: []token{ 16 | {typ: selectTok, literal: "SELECT"}, 17 | {typ: identifier, literal: "a"}, 18 | {typ: comma, literal: ","}, 19 | {typ: identifier, literal: "b"}, 20 | {typ: from, literal: "FROM"}, 21 | {typ: identifier, literal: "mytable"}, 22 | {typ: where, literal: "WHERE"}, 23 | {typ: identifier, literal: "a"}, 24 | {typ: equal, literal: "="}, 25 | {typ: number, literal: "1"}, 26 | {typ: and, literal: "AND"}, 27 | {typ: identifier, literal: "b"}, 28 | {typ: equal, literal: "="}, 29 | {typ: stringTok, literal: "foo"}, 30 | {typ: eof, literal: ""}, 31 | }, 32 | }, 33 | { 34 | input: `select a from mytable`, 35 | want: []token{ 36 | {typ: selectTok, literal: "select"}, 37 | {typ: identifier, 
literal: "a"}, 38 | {typ: from, literal: "from"}, 39 | {typ: identifier, literal: "mytable"}, 40 | {typ: eof, literal: ""}, 41 | }, 42 | }, 43 | { 44 | input: `SELECT a FROM mytable ORDER BY a ASC`, 45 | want: []token{ 46 | {typ: selectTok, literal: "SELECT"}, 47 | {typ: identifier, literal: "a"}, 48 | {typ: from, literal: "FROM"}, 49 | {typ: identifier, literal: "mytable"}, 50 | {typ: order, literal: "ORDER"}, 51 | {typ: by, literal: "BY"}, 52 | {typ: identifier, literal: "a"}, 53 | {typ: asc, literal: "ASC"}, 54 | {typ: eof, literal: ""}, 55 | }, 56 | }, 57 | { 58 | input: `SELECT a FROM mytable ORDER BY a DESC`, 59 | want: []token{ 60 | {typ: selectTok, literal: "SELECT"}, 61 | {typ: identifier, literal: "a"}, 62 | {typ: from, literal: "FROM"}, 63 | {typ: identifier, literal: "mytable"}, 64 | {typ: order, literal: "ORDER"}, 65 | {typ: by, literal: "BY"}, 66 | {typ: identifier, literal: "a"}, 67 | {typ: desc, literal: "DESC"}, 68 | {typ: eof, literal: ""}, 69 | }, 70 | }, 71 | { 72 | input: `SELECT a, b FROM mytable1 JOIN mytable2 ON mytable1.a = mytable2.b`, 73 | want: []token{ 74 | {typ: selectTok, literal: "SELECT"}, 75 | {typ: identifier, literal: "a"}, 76 | {typ: comma, literal: ","}, 77 | {typ: identifier, literal: "b"}, 78 | {typ: from, literal: "FROM"}, 79 | {typ: identifier, literal: "mytable1"}, 80 | {typ: join, literal: "JOIN"}, 81 | {typ: identifier, literal: "mytable2"}, 82 | {typ: on, literal: "ON"}, 83 | {typ: identifier, literal: "mytable1.a"}, 84 | {typ: equal, literal: "="}, 85 | {typ: identifier, literal: "mytable2.b"}, 86 | {typ: eof, literal: ""}, 87 | }, 88 | }, 89 | { 90 | input: `SELECT a, MAX(b) AS max_b FROM mytable GROUP BY a`, 91 | want: []token{ 92 | {typ: selectTok, literal: "SELECT"}, 93 | {typ: identifier, literal: "a"}, 94 | {typ: comma, literal: ","}, 95 | {typ: max, literal: "MAX"}, 96 | {typ: lparen, literal: "("}, 97 | {typ: identifier, literal: "b"}, 98 | {typ: rparen, literal: ")"}, 99 | {typ: as, literal: "AS"}, 100 | {typ: 
identifier, literal: "max_b"}, 101 | {typ: from, literal: "FROM"}, 102 | {typ: identifier, literal: "mytable"}, 103 | {typ: group, literal: "GROUP"}, 104 | {typ: by, literal: "BY"}, 105 | {typ: identifier, literal: "a"}, 106 | {typ: eof, literal: ""}, 107 | }, 108 | }, 109 | { 110 | input: `SELECT a, COUNT(*) AS max_b FROM mytable GROUP BY a`, 111 | want: []token{ 112 | {typ: selectTok, literal: "SELECT"}, 113 | {typ: identifier, literal: "a"}, 114 | {typ: comma, literal: ","}, 115 | {typ: count, literal: "COUNT"}, 116 | {typ: lparen, literal: "("}, 117 | {typ: asterisk, literal: "*"}, 118 | {typ: rparen, literal: ")"}, 119 | {typ: as, literal: "AS"}, 120 | {typ: identifier, literal: "max_b"}, 121 | {typ: from, literal: "FROM"}, 122 | {typ: identifier, literal: "mytable"}, 123 | {typ: group, literal: "GROUP"}, 124 | {typ: by, literal: "BY"}, 125 | {typ: identifier, literal: "a"}, 126 | {typ: eof, literal: ""}, 127 | }, 128 | }, 129 | { 130 | input: `EXPLAIN SELECT a, b FROM mytable WHERE a = 1 AND b = 'foo'`, 131 | want: []token{ 132 | {typ: explain, literal: "EXPLAIN"}, 133 | {typ: selectTok, literal: "SELECT"}, 134 | {typ: identifier, literal: "a"}, 135 | {typ: comma, literal: ","}, 136 | {typ: identifier, literal: "b"}, 137 | {typ: from, literal: "FROM"}, 138 | {typ: identifier, literal: "mytable"}, 139 | {typ: where, literal: "WHERE"}, 140 | {typ: identifier, literal: "a"}, 141 | {typ: equal, literal: "="}, 142 | {typ: number, literal: "1"}, 143 | {typ: and, literal: "AND"}, 144 | {typ: identifier, literal: "b"}, 145 | {typ: equal, literal: "="}, 146 | {typ: stringTok, literal: "foo"}, 147 | {typ: eof, literal: ""}, 148 | }, 149 | }, 150 | { 151 | input: `INSERT INTO mytable (a, b) VALUES (1, 'foo', $1)`, 152 | want: []token{ 153 | {typ: insert, literal: "INSERT"}, 154 | {typ: into, literal: "INTO"}, 155 | {typ: identifier, literal: "mytable"}, 156 | {typ: lparen, literal: "("}, 157 | {typ: identifier, literal: "a"}, 158 | {typ: comma, literal: ","}, 159 | 
{typ: identifier, literal: "b"}, 160 | {typ: rparen, literal: ")"}, 161 | {typ: values, literal: "VALUES"}, 162 | {typ: lparen, literal: "("}, 163 | {typ: number, literal: "1"}, 164 | {typ: comma, literal: ","}, 165 | {typ: stringTok, literal: "foo"}, 166 | {typ: comma, literal: ","}, 167 | {typ: placeholder, literal: "$1"}, 168 | {typ: rparen, literal: ")"}, 169 | {typ: eof, literal: ""}, 170 | }, 171 | }, 172 | { 173 | input: `DELETE FROM mytable WHERE a = 1`, 174 | want: []token{ 175 | {typ: deleteTok, literal: "DELETE"}, 176 | {typ: from, literal: "FROM"}, 177 | {typ: identifier, literal: "mytable"}, 178 | {typ: where, literal: "WHERE"}, 179 | {typ: identifier, literal: "a"}, 180 | {typ: equal, literal: "="}, 181 | {typ: number, literal: "1"}, 182 | {typ: eof, literal: ""}, 183 | }, 184 | }, 185 | { 186 | input: `UPDATE mytable SET a = 1 WHERE b = 'foo'`, 187 | want: []token{ 188 | {typ: update, literal: "UPDATE"}, 189 | {typ: identifier, literal: "mytable"}, 190 | {typ: set, literal: "SET"}, 191 | {typ: identifier, literal: "a"}, 192 | {typ: equal, literal: "="}, 193 | {typ: number, literal: "1"}, 194 | {typ: where, literal: "WHERE"}, 195 | {typ: identifier, literal: "b"}, 196 | {typ: equal, literal: "="}, 197 | {typ: stringTok, literal: "foo"}, 198 | {typ: eof, literal: ""}, 199 | }, 200 | }, 201 | { 202 | input: `CREATE TABLE mytable (a INT, b VARCHAR)`, 203 | want: []token{ 204 | {typ: create, literal: "CREATE"}, 205 | {typ: table, literal: "TABLE"}, 206 | {typ: identifier, literal: "mytable"}, 207 | {typ: lparen, literal: "("}, 208 | {typ: identifier, literal: "a"}, 209 | {typ: intTok, literal: "INT"}, 210 | {typ: comma, literal: ","}, 211 | {typ: identifier, literal: "b"}, 212 | {typ: varchar, literal: "VARCHAR"}, 213 | {typ: rparen, literal: ")"}, 214 | {typ: eof, literal: ""}, 215 | }, 216 | }, 217 | { 218 | input: `BEGIN`, 219 | want: []token{ 220 | {typ: begin, literal: "BEGIN"}, 221 | {typ: eof, literal: ""}, 222 | }, 223 | }, 224 | { 225 | input: 
`COMMIT`, 226 | want: []token{ 227 | {typ: commit, literal: "COMMIT"}, 228 | {typ: eof, literal: ""}, 229 | }, 230 | }, 231 | { 232 | input: `ROLLBACK`, 233 | want: []token{ 234 | {typ: rollback, literal: "ROLLBACK"}, 235 | {typ: eof, literal: ""}, 236 | }, 237 | }, 238 | } 239 | for _, tt := range tests { 240 | t.Run(tt.input, func(t *testing.T) { 241 | l := NewLexer(tt.input) 242 | var got []token 243 | for tok := range l.TokenIterator() { 244 | got = append(got, tok) 245 | } 246 | if !reflect.DeepEqual(got, tt.want) { 247 | t.Errorf("got %v, want %v", got, tt.want) 248 | } 249 | }) 250 | } 251 | } 252 | -------------------------------------------------------------------------------- /internal/parse/parser_test.go: -------------------------------------------------------------------------------- 1 | package parse 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/abekoh/simple-db/internal/query" 8 | "github.com/abekoh/simple-db/internal/record/schema" 9 | ) 10 | 11 | func TestParser_Query(t *testing.T) { 12 | tests := []struct { 13 | name string 14 | s string 15 | want *QueryData 16 | wantErr bool 17 | }{ 18 | { 19 | name: "SELECT full", 20 | s: "SELECT a, b, c FROM mytable WHERE a = 1 AND b = 'foo' AND c = $1", 21 | want: &QueryData{ 22 | fields: []schema.FieldName{"a", "b", "c"}, 23 | tables: []string{"mytable"}, 24 | pred: query.Predicate{ 25 | query.NewTerm(schema.FieldName("a"), schema.ConstantInt32(1)), 26 | query.NewTerm(schema.FieldName("b"), schema.ConstantStr("foo")), 27 | query.NewTerm(schema.FieldName("c"), schema.Placeholder(1)), 28 | }, 29 | }, 30 | wantErr: false, 31 | }, 32 | { 33 | name: "SELECT without where", 34 | s: "SELECT a, b FROM mytable", 35 | want: &QueryData{ 36 | fields: []schema.FieldName{"a", "b"}, 37 | tables: []string{"mytable"}, 38 | }, 39 | wantErr: false, 40 | }, 41 | { 42 | name: "SELECT only one field", 43 | s: "SELECT a FROM mytable", 44 | want: &QueryData{ 45 | fields: []schema.FieldName{"a"}, 46 | tables: 
[]string{"mytable"}, 47 | }, 48 | wantErr: false, 49 | }, 50 | { 51 | name: "SELECT product", 52 | s: "SELECT a, x FROM mytable1, mytable2", 53 | want: &QueryData{ 54 | fields: []schema.FieldName{"a", "x"}, 55 | tables: []string{"mytable1", "mytable2"}, 56 | }, 57 | wantErr: false, 58 | }, 59 | { 60 | name: "SELECT with join", 61 | s: "SELECT a, x FROM mytable1 JOIN mytable2 ON a = x", 62 | want: &QueryData{ 63 | fields: []schema.FieldName{"a", "x"}, 64 | tables: []string{"mytable1", "mytable2"}, 65 | pred: query.Predicate{ 66 | query.NewTerm(schema.FieldName("a"), schema.FieldName("x")), 67 | }, 68 | }, 69 | wantErr: false, 70 | }, 71 | { 72 | name: "SELECT with 2 joins", 73 | s: "SELECT a, x, y FROM mytable1 JOIN mytable2 ON a = x JOIN mytable3 ON x = y", 74 | want: &QueryData{ 75 | fields: []schema.FieldName{"a", "x", "y"}, 76 | tables: []string{"mytable1", "mytable2", "mytable3"}, 77 | pred: query.Predicate{ 78 | query.NewTerm(schema.FieldName("a"), schema.FieldName("x")), 79 | query.NewTerm(schema.FieldName("x"), schema.FieldName("y")), 80 | }, 81 | }, 82 | }, 83 | { 84 | name: "SELECT with 2 joins and where", 85 | s: "SELECT a, x, y FROM mytable1 JOIN mytable2 ON a = x JOIN mytable3 ON x = y WHERE a = 1", 86 | want: &QueryData{ 87 | fields: []schema.FieldName{"a", "x", "y"}, 88 | tables: []string{"mytable1", "mytable2", "mytable3"}, 89 | pred: query.Predicate{ 90 | query.NewTerm(schema.FieldName("a"), schema.FieldName("x")), 91 | query.NewTerm(schema.FieldName("x"), schema.FieldName("y")), 92 | query.NewTerm(schema.FieldName("a"), schema.ConstantInt32(1)), 93 | }, 94 | }, 95 | }, 96 | { 97 | name: "SELECT with ORDER BY", 98 | s: "SELECT a, b FROM mytable ORDER BY a", 99 | want: &QueryData{ 100 | fields: []schema.FieldName{"a", "b"}, 101 | tables: []string{"mytable"}, 102 | order: query.Order{ 103 | query.OrderElement{Field: "a", OrderType: query.Asc}, 104 | }, 105 | }, 106 | }, 107 | { 108 | name: "SELECT with ORDER BY DESC", 109 | s: "SELECT a, b FROM 
mytable ORDER BY a DESC", 110 | want: &QueryData{ 111 | fields: []schema.FieldName{"a", "b"}, 112 | tables: []string{"mytable"}, 113 | order: query.Order{ 114 | query.OrderElement{Field: "a", OrderType: query.Desc}, 115 | }, 116 | }, 117 | }, 118 | { 119 | name: "SELECT with ORDER BY ASC", 120 | s: "SELECT a, b FROM mytable ORDER BY a ASC", 121 | want: &QueryData{ 122 | fields: []schema.FieldName{"a", "b"}, 123 | tables: []string{"mytable"}, 124 | order: query.Order{ 125 | query.OrderElement{Field: "a", OrderType: query.Asc}, 126 | }, 127 | }, 128 | }, 129 | { 130 | name: "SELECT with GROUP BY", 131 | s: "SELECT a, MAX(b) AS max_b FROM mytable GROUP BY a", 132 | want: &QueryData{ 133 | fields: []schema.FieldName{"a", "max_b"}, 134 | aggregationFuncs: []query.AggregationFunc{ 135 | query.NewMaxFunc("b", "max_b"), 136 | }, 137 | tables: []string{"mytable"}, 138 | groupFields: []schema.FieldName{"a"}, 139 | }, 140 | }, 141 | { 142 | name: "SELECT with GROUP BY using *", 143 | s: "SELECT a, COUNT(*) AS max_b FROM mytable GROUP BY a", 144 | want: &QueryData{ 145 | fields: []schema.FieldName{"a", "max_b"}, 146 | aggregationFuncs: []query.AggregationFunc{ 147 | query.NewCountFunc("*", "max_b"), 148 | }, 149 | tables: []string{"mytable"}, 150 | groupFields: []schema.FieldName{"a"}, 151 | }, 152 | }, 153 | } 154 | for _, tt := range tests { 155 | t.Run(tt.name, func(t *testing.T) { 156 | p := NewParser(tt.s) 157 | got, err := p.Query() 158 | if (err != nil) != tt.wantErr { 159 | t.Errorf("Query() error = %v, wantErr %v", err, tt.wantErr) 160 | return 161 | } 162 | if !reflect.DeepEqual(got, tt.want) { 163 | t.Errorf("Query() got = %v, want %v", got, tt.want) 164 | } 165 | }) 166 | } 167 | } 168 | 169 | func TestParser_Insert(t *testing.T) { 170 | tests := []struct { 171 | name string 172 | s string 173 | want *InsertData 174 | wantErr bool 175 | }{ 176 | { 177 | name: "INSERT", 178 | s: "INSERT INTO mytable (a, b, c) VALUES (1, 'foo', $1)", 179 | want: &InsertData{ 180 | 
table: "mytable", 181 | fields: []schema.FieldName{"a", "b", "c"}, 182 | values: []schema.Constant{ 183 | schema.ConstantInt32(1), 184 | schema.ConstantStr("foo"), 185 | schema.Placeholder(1), 186 | }, 187 | }, 188 | wantErr: false, 189 | }, 190 | } 191 | for _, tt := range tests { 192 | t.Run(tt.name, func(t *testing.T) { 193 | p := NewParser(tt.s) 194 | got, err := p.Insert() 195 | if (err != nil) != tt.wantErr { 196 | t.Errorf("Insert() error = %v, wantErr %v", err, tt.wantErr) 197 | return 198 | } 199 | if !reflect.DeepEqual(got, tt.want) { 200 | t.Errorf("Insert() got = %v, want %v", got, tt.want) 201 | } 202 | }) 203 | } 204 | } 205 | 206 | func TestParser_Modify(t *testing.T) { 207 | tests := []struct { 208 | name string 209 | s string 210 | want *ModifyData 211 | wantErr bool 212 | }{ 213 | { 214 | name: "UPDATE full", 215 | s: "UPDATE mytable SET a = 1 WHERE b = 'foo'", 216 | want: &ModifyData{ 217 | table: "mytable", 218 | field: "a", 219 | value: schema.ConstantInt32(1), 220 | pred: query.Predicate{ 221 | query.NewTerm( 222 | schema.FieldName("b"), 223 | schema.ConstantStr("foo")), 224 | }, 225 | }, 226 | wantErr: false, 227 | }, 228 | { 229 | name: "UPDATE without where", 230 | s: "UPDATE mytable SET a = 1", 231 | want: &ModifyData{ 232 | table: "mytable", 233 | field: "a", 234 | value: schema.ConstantInt32(1), 235 | }, 236 | wantErr: false, 237 | }, 238 | } 239 | for _, tt := range tests { 240 | t.Run(tt.name, func(t *testing.T) { 241 | p := NewParser(tt.s) 242 | got, err := p.Modify() 243 | if (err != nil) != tt.wantErr { 244 | t.Errorf("Modify() error = %v, wantErr %v", err, tt.wantErr) 245 | return 246 | } 247 | if !reflect.DeepEqual(got, tt.want) { 248 | t.Errorf("Modify() got = %v, want %v", got, tt.want) 249 | } 250 | }) 251 | } 252 | } 253 | 254 | func TestParser_Delete(t *testing.T) { 255 | tests := []struct { 256 | name string 257 | s string 258 | want *DeleteData 259 | wantErr bool 260 | }{ 261 | { 262 | name: "DELETE full", 263 | s: "DELETE 
FROM mytable WHERE a = 1", 264 | want: &DeleteData{ 265 | table: "mytable", 266 | pred: query.Predicate{ 267 | query.NewTerm( 268 | schema.FieldName("a"), 269 | schema.ConstantInt32(1), 270 | ), 271 | }, 272 | }, 273 | wantErr: false, 274 | }, 275 | { 276 | name: "DELETE without where", 277 | s: "DELETE FROM mytable", 278 | want: &DeleteData{ 279 | table: "mytable", 280 | }, 281 | wantErr: false, 282 | }, 283 | } 284 | for _, tt := range tests { 285 | t.Run(tt.name, func(t *testing.T) { 286 | p := NewParser(tt.s) 287 | got, err := p.Delete() 288 | if (err != nil) != tt.wantErr { 289 | t.Errorf("Delete() error = %v, wantErr %v", err, tt.wantErr) 290 | return 291 | } 292 | if !reflect.DeepEqual(got, tt.want) { 293 | t.Errorf("Delete() got = %v, want %v", got, tt.want) 294 | } 295 | }) 296 | } 297 | } 298 | 299 | func TestParser_CreateTable(t *testing.T) { 300 | tests := []struct { 301 | name string 302 | s string 303 | want *CreateTableData 304 | wantErr bool 305 | }{ 306 | { 307 | name: "CREATE TABLE", 308 | s: "CREATE TABLE mytable (a INT, b VARCHAR(10))", 309 | want: &CreateTableData{ 310 | table: "mytable", 311 | sche: func() schema.Schema { 312 | s := schema.NewSchema() 313 | s.AddField("a", schema.NewInt32Field()) 314 | s.AddField("b", schema.NewVarcharField(10)) 315 | return s 316 | }(), 317 | }, 318 | wantErr: false, 319 | }, 320 | } 321 | for _, tt := range tests { 322 | t.Run(tt.name, func(t *testing.T) { 323 | p := NewParser(tt.s) 324 | got, err := p.CreateTable() 325 | if (err != nil) != tt.wantErr { 326 | t.Errorf("CreateTable() error = %v, wantErr %v", err, tt.wantErr) 327 | return 328 | } 329 | if !reflect.DeepEqual(got, tt.want) { 330 | t.Errorf("CreateTable() got = %v, want %v", got, tt.want) 331 | } 332 | }) 333 | } 334 | } 335 | 336 | func TestParser_CreateView(t *testing.T) { 337 | tests := []struct { 338 | name string 339 | s string 340 | want *CreateViewData 341 | wantErr bool 342 | }{ 343 | { 344 | name: "CREATE VIEW", 345 | s: "CREATE VIEW 
myview AS SELECT a, b FROM mytable WHERE a = 1 AND b = 'foo'", 346 | want: &CreateViewData{ 347 | view: "myview", 348 | query: &QueryData{ 349 | fields: []schema.FieldName{"a", "b"}, 350 | tables: []string{"mytable"}, 351 | pred: query.Predicate{ 352 | query.NewTerm(schema.FieldName("a"), schema.ConstantInt32(1)), 353 | query.NewTerm(schema.FieldName("b"), schema.ConstantStr("foo")), 354 | }, 355 | }, 356 | }, 357 | wantErr: false, 358 | }, 359 | } 360 | for _, tt := range tests { 361 | t.Run(tt.name, func(t *testing.T) { 362 | p := NewParser(tt.s) 363 | got, err := p.CreateView() 364 | if (err != nil) != tt.wantErr { 365 | t.Errorf("CreateView() error = %v, wantErr %v", err, tt.wantErr) 366 | return 367 | } 368 | if !reflect.DeepEqual(got, tt.want) { 369 | t.Errorf("CreateView() got = %v, want %v", got, tt.want) 370 | } 371 | }) 372 | } 373 | } 374 | 375 | func TestParser_CreateIndex(t *testing.T) { 376 | tests := []struct { 377 | name string 378 | s string 379 | want *CreateIndexData 380 | wantErr bool 381 | }{ 382 | { 383 | name: "CREATE INDEX", 384 | s: "CREATE INDEX myindex ON mytable (a)", 385 | want: &CreateIndexData{ 386 | index: "myindex", 387 | table: "mytable", 388 | field: "a", 389 | }, 390 | wantErr: false, 391 | }, 392 | } 393 | for _, tt := range tests { 394 | t.Run(tt.name, func(t *testing.T) { 395 | p := NewParser(tt.s) 396 | got, err := p.CreateIndex() 397 | if (err != nil) != tt.wantErr { 398 | t.Errorf("CreateIndex() error = %v, wantErr %v", err, tt.wantErr) 399 | return 400 | } 401 | if !reflect.DeepEqual(got, tt.want) { 402 | t.Errorf("CreateIndex() got = %v, want %v", got, tt.want) 403 | } 404 | }) 405 | } 406 | } 407 | -------------------------------------------------------------------------------- /internal/plan/error.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import "errors" 4 | 5 | var ( 6 | ErrFieldNotFound = errors.New("field not found") 7 | ) 8 | 
-------------------------------------------------------------------------------- /internal/plan/group.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "fmt" 5 | "maps" 6 | "slices" 7 | 8 | "github.com/abekoh/simple-db/internal/query" 9 | "github.com/abekoh/simple-db/internal/record/schema" 10 | "github.com/abekoh/simple-db/internal/statement" 11 | "github.com/abekoh/simple-db/internal/transaction" 12 | ) 13 | 14 | type GroupByScan struct { 15 | scan query.Scan 16 | groupFields []schema.FieldName 17 | groupValues map[schema.FieldName]schema.Constant 18 | aggregationFuncs []query.AggregationFunc 19 | moreGroups bool 20 | } 21 | 22 | var _ query.Scan = (*GroupByScan)(nil) 23 | 24 | func NewGroupByScan(scan query.Scan, fields []schema.FieldName, aggregationFuncs []query.AggregationFunc) (*GroupByScan, error) { 25 | gs := GroupByScan{scan: scan, groupFields: fields, aggregationFuncs: aggregationFuncs} 26 | if err := gs.BeforeFirst(); err != nil { 27 | return nil, err 28 | } 29 | return &gs, nil 30 | } 31 | 32 | func (g *GroupByScan) Val(fieldName schema.FieldName) (schema.Constant, error) { 33 | if slices.Contains(g.groupFields, fieldName) { 34 | return g.groupValues[fieldName], nil 35 | } 36 | for _, f := range g.aggregationFuncs { 37 | if f.AliasName() == fieldName { 38 | return f.Val(), nil 39 | } 40 | } 41 | return nil, ErrFieldNotFound 42 | } 43 | 44 | func (g *GroupByScan) BeforeFirst() error { 45 | if err := g.scan.BeforeFirst(); err != nil { 46 | return fmt.Errorf("g.scan.BeforeFirst error: %w", err) 47 | } 48 | ok, err := g.scan.Next() 49 | if err != nil { 50 | return fmt.Errorf("g.scan.Next error: %w", err) 51 | } 52 | g.moreGroups = ok 53 | return nil 54 | } 55 | 56 | func (g *GroupByScan) Next() (bool, error) { 57 | if !g.moreGroups { 58 | return false, nil 59 | } 60 | for _, f := range g.aggregationFuncs { 61 | if err := f.First(g.scan); err != nil { 62 | return false, 
fmt.Errorf("f.First error: %w", err) 63 | } 64 | } 65 | g.groupValues = make(map[schema.FieldName]schema.Constant) 66 | for _, f := range g.groupFields { 67 | val, err := g.scan.Val(f) 68 | if err != nil { 69 | return false, fmt.Errorf("g.scan.Val error: %w", err) 70 | } 71 | g.groupValues[f] = val 72 | } 73 | for { 74 | ok, err := g.scan.Next() 75 | if err != nil { 76 | return false, fmt.Errorf("g.scan.Next error: %w", err) 77 | } 78 | g.moreGroups = ok 79 | if !g.moreGroups { 80 | return true, nil 81 | } 82 | gv := make(map[schema.FieldName]schema.Constant) 83 | for _, f := range g.groupFields { 84 | val, err := g.scan.Val(f) 85 | if err != nil { 86 | return false, fmt.Errorf("g.scan.Val error: %w", err) 87 | } 88 | gv[f] = val 89 | } 90 | if !maps.Equal(g.groupValues, gv) { 91 | return true, nil 92 | } 93 | for _, f := range g.aggregationFuncs { 94 | if err := f.Next(g.scan); err != nil { 95 | return false, fmt.Errorf("f.Next error: %w", err) 96 | } 97 | } 98 | } 99 | } 100 | 101 | func (g *GroupByScan) Int32(fieldName schema.FieldName) (int32, error) { 102 | v, err := g.Val(fieldName) 103 | if err != nil { 104 | return 0, err 105 | } 106 | intV, ok := v.(schema.ConstantInt32) 107 | if !ok { 108 | return 0, schema.ErrTypeAssertionFailed 109 | } 110 | return int32(intV), nil 111 | } 112 | 113 | func (g *GroupByScan) Str(fieldName schema.FieldName) (string, error) { 114 | v, err := g.Val(fieldName) 115 | if err != nil { 116 | return "", err 117 | } 118 | strV, ok := v.(schema.ConstantStr) 119 | if !ok { 120 | return "", schema.ErrTypeAssertionFailed 121 | } 122 | return string(strV), nil 123 | } 124 | 125 | func (g *GroupByScan) HasField(fieldName schema.FieldName) bool { 126 | if slices.Contains(g.groupFields, fieldName) { 127 | return true 128 | } 129 | for _, f := range g.aggregationFuncs { 130 | if f.AliasName() == fieldName { 131 | return true 132 | } 133 | } 134 | return false 135 | } 136 | 137 | func (g *GroupByScan) Close() error { 138 | if err := 
g.scan.Close(); err != nil { 139 | return fmt.Errorf("g.scan.Close error: %w", err) 140 | } 141 | return nil 142 | } 143 | 144 | type GroupByPlan struct { 145 | p Plan 146 | groupFields []schema.FieldName 147 | aggregationFuncs []query.AggregationFunc 148 | sche schema.Schema 149 | } 150 | 151 | var _ Plan = (*GroupByPlan)(nil) 152 | 153 | func NewGroupByPlan(tx *transaction.Transaction, p Plan, groupFields []schema.FieldName, aggregationFuncs []query.AggregationFunc) *GroupByPlan { 154 | s := schema.NewSchema() 155 | for _, fn := range groupFields { 156 | s.Add(fn, *p.Schema()) 157 | } 158 | for _, f := range aggregationFuncs { 159 | s.AddInt32Field(f.AliasName()) 160 | } 161 | order := make(query.Order, len(groupFields)) 162 | for i, f := range groupFields { 163 | order[i] = query.OrderElement{ 164 | Field: f, 165 | OrderType: query.Asc, 166 | } 167 | } 168 | return &GroupByPlan{ 169 | p: NewSortPlan(tx, p, order), 170 | groupFields: groupFields, 171 | aggregationFuncs: aggregationFuncs, 172 | sche: s, 173 | } 174 | } 175 | 176 | func (g GroupByPlan) Result() {} 177 | 178 | func (g GroupByPlan) Placeholders(findSchema func(tableName string) (*schema.Schema, error)) map[int]schema.FieldType { 179 | return g.p.Placeholders(findSchema) 180 | } 181 | 182 | func (g GroupByPlan) SwapParams(params map[int]schema.Constant) (statement.Bound, error) { 183 | newP, err := g.p.SwapParams(params) 184 | if err != nil { 185 | return nil, fmt.Errorf("g.p.SwapParams error: %w", err) 186 | } 187 | return &BoundPlan{ 188 | Plan: newP.(Plan), 189 | }, nil 190 | } 191 | 192 | func (g GroupByPlan) Open() (query.Scan, error) { 193 | s, err := g.p.Open() 194 | if err != nil { 195 | return nil, fmt.Errorf("g.p.Open error: %w", err) 196 | } 197 | gs, err := NewGroupByScan(s, g.groupFields, g.aggregationFuncs) 198 | if err != nil { 199 | return nil, fmt.Errorf("NewGroupByScan error: %w", err) 200 | } 201 | return gs, nil 202 | } 203 | 204 | func (g GroupByPlan) BlockAccessed() int { 205 | 
return g.p.BlockAccessed() 206 | } 207 | 208 | func (g GroupByPlan) RecordsOutput() int { 209 | numGroups := 1 210 | for _, f := range g.groupFields { 211 | numGroups *= g.p.DistinctValues(f) 212 | } 213 | return numGroups 214 | } 215 | 216 | func (g GroupByPlan) DistinctValues(fieldName schema.FieldName) int { 217 | if slices.Contains(g.groupFields, fieldName) { 218 | return g.p.DistinctValues(fieldName) 219 | } 220 | return g.RecordsOutput() 221 | } 222 | 223 | func (g GroupByPlan) Schema() *schema.Schema { 224 | return &g.sche 225 | } 226 | 227 | func (g GroupByPlan) Info() Info { 228 | conditions := make(map[string][]string) 229 | for _, f := range g.groupFields { 230 | conditions["groupFields"] = append(conditions["groupFields"], string(f)) 231 | } 232 | for _, f := range g.aggregationFuncs { 233 | conditions["aggregationFuncs"] = append(conditions["aggregationFuncs"], f.String()) 234 | } 235 | return Info{ 236 | NodeType: "GroupBy", 237 | Conditions: conditions, 238 | BlockAccessed: g.BlockAccessed(), 239 | RecordsOutput: g.RecordsOutput(), 240 | Children: []Info{g.p.Info()}, 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /internal/plan/group_test.go: -------------------------------------------------------------------------------- 1 | package plan_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/abekoh/simple-db/internal/plan" 10 | "github.com/abekoh/simple-db/internal/query" 11 | "github.com/abekoh/simple-db/internal/record" 12 | "github.com/abekoh/simple-db/internal/record/schema" 13 | "github.com/abekoh/simple-db/internal/simpledb" 14 | "github.com/abekoh/simple-db/internal/transaction" 15 | ) 16 | 17 | func TestGroupByPlan(t *testing.T) { 18 | transaction.CleanupLockTable(t) 19 | ctx := context.Background() 20 | db, err := simpledb.New(ctx, t.TempDir()) 21 | if err != nil { 22 | t.Fatal(err) 23 | } 24 | tx, err := db.NewTx(ctx) 25 | if err != nil { 26 | 
t.Fatal(err) 27 | } 28 | 29 | sche := schema.NewSchema() 30 | sche.AddStrField("department", 10) 31 | sche.AddInt32Field("score") 32 | if err := db.MetadataMgr().CreateTable("mytable", sche, tx); err != nil { 33 | t.Fatal(err) 34 | } 35 | 36 | layout := record.NewLayoutSchema(sche) 37 | updateScan, err := record.NewTableScan(tx, "mytable", layout) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | 42 | if err := updateScan.BeforeFirst(); err != nil { 43 | t.Fatal(err) 44 | } 45 | for _, v := range []struct { 46 | department string 47 | score int32 48 | }{ 49 | {"math", 93}, 50 | {"math", 87}, 51 | {"math", 92}, 52 | {"math", 85}, 53 | {"english", 85}, 54 | {"english", 90}, 55 | {"english", 88}, 56 | } { 57 | if err := updateScan.Insert(); err != nil { 58 | t.Fatal(err) 59 | } 60 | if err := updateScan.SetVal("department", schema.ConstantStr(v.department)); err != nil { 61 | t.Fatal(err) 62 | } 63 | if err := updateScan.SetVal("score", schema.ConstantInt32(v.score)); err != nil { 64 | t.Fatal(err) 65 | } 66 | } 67 | if err := updateScan.Close(); err != nil { 68 | t.Fatal(err) 69 | } 70 | 71 | tablePlan, err := plan.NewTablePlan("mytable", tx, db.MetadataMgr()) 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | groupByPlan := plan.NewGroupByPlan(tx, 76 | tablePlan, 77 | []schema.FieldName{"department"}, 78 | []query.AggregationFunc{ 79 | query.NewCountFunc("*", "count_score"), 80 | query.NewMaxFunc("score", "max_score"), 81 | query.NewMinFunc("score", "min_score"), 82 | query.NewSumFunc("score", "sum_score"), 83 | }, 84 | ) 85 | sortPlan := plan.NewSortPlan(tx, groupByPlan, query.Order{query.OrderElement{Field: "department", OrderType: query.Asc}}) 86 | projectPlan := plan.NewProjectPlan(sortPlan, []schema.FieldName{"department", "count_score", "max_score", "min_score"}) 87 | 88 | queryScan, err := projectPlan.Open() 89 | if err != nil { 90 | t.Fatal(err) 91 | } 92 | if err := queryScan.BeforeFirst(); err != nil { 93 | t.Fatal(err) 94 | } 95 | 96 | res := 
make([]string, 0, 2) 97 | for { 98 | ok, err := queryScan.Next() 99 | if err != nil { 100 | t.Fatal(err) 101 | } 102 | if !ok { 103 | break 104 | } 105 | department, err := queryScan.Str("department") 106 | if err != nil { 107 | t.Fatal(err) 108 | } 109 | countScore, err := queryScan.Int32("count_score") 110 | if err != nil { 111 | t.Fatal(err) 112 | } 113 | maxScore, err := queryScan.Int32("max_score") 114 | if err != nil { 115 | t.Fatal(err) 116 | } 117 | minScore, err := queryScan.Int32("min_score") 118 | if err != nil { 119 | t.Fatal(err) 120 | } 121 | res = append(res, fmt.Sprintf("%s: count=%d, max=%d, min=%d", department, countScore, maxScore, minScore)) 122 | } 123 | if err := queryScan.Close(); err != nil { 124 | t.Fatal(err) 125 | } 126 | if !reflect.DeepEqual(res, []string{ 127 | "english: count=3, max=90, min=85", 128 | "math: count=4, max=93, min=85", 129 | }) { 130 | t.Fatalf("unexpected result: %v", res) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /internal/plan/materialize.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | 7 | "github.com/abekoh/simple-db/internal/query" 8 | "github.com/abekoh/simple-db/internal/record" 9 | "github.com/abekoh/simple-db/internal/record/schema" 10 | "github.com/abekoh/simple-db/internal/statement" 11 | "github.com/abekoh/simple-db/internal/transaction" 12 | "github.com/oklog/ulid/v2" 13 | ) 14 | 15 | type TempTable struct { 16 | tx *transaction.Transaction 17 | tableName string 18 | layout *record.Layout 19 | } 20 | 21 | func NewTempTable(tx *transaction.Transaction, sche schema.Schema) *TempTable { 22 | l := record.NewLayoutSchema(sche) 23 | return &TempTable{tx: tx, tableName: fmt.Sprintf("temp_table_%s", ulid.Make()), layout: l} 24 | } 25 | 26 | func (t TempTable) Open() (query.UpdateScan, error) { 27 | ts, err := record.NewTableScan(t.tx, t.tableName, t.layout) 28 
| if err != nil { 29 | return nil, fmt.Errorf("record.NewTableScan error: %w", err) 30 | } 31 | return ts, nil 32 | } 33 | 34 | func (t TempTable) TableName() string { 35 | return t.tableName 36 | } 37 | 38 | func (t TempTable) Layout() *record.Layout { 39 | return t.layout 40 | } 41 | 42 | type MaterializePlan struct { 43 | srcPlan Plan 44 | tx *transaction.Transaction 45 | } 46 | 47 | var _ Plan = (*MaterializePlan)(nil) 48 | 49 | func NewMaterializePlan(tx *transaction.Transaction, p Plan) *MaterializePlan { 50 | return &MaterializePlan{srcPlan: p, tx: tx} 51 | } 52 | 53 | func (p MaterializePlan) Result() {} 54 | 55 | func (p MaterializePlan) Info() Info { 56 | return Info{ 57 | NodeType: "Materialize", 58 | BlockAccessed: p.BlockAccessed(), 59 | RecordsOutput: p.RecordsOutput(), 60 | Children: []Info{p.srcPlan.Info()}, 61 | } 62 | } 63 | 64 | func (p MaterializePlan) Placeholders(findSchema func(tableName string) (*schema.Schema, error)) map[int]schema.FieldType { 65 | return p.srcPlan.Placeholders(findSchema) 66 | } 67 | 68 | func (p MaterializePlan) SwapParams(params map[int]schema.Constant) (statement.Bound, error) { 69 | newSrcPlan, err := p.srcPlan.SwapParams(params) 70 | if err != nil { 71 | return nil, fmt.Errorf("srcPlan.SwapParams error: %w", err) 72 | } 73 | np, ok := newSrcPlan.(*BoundPlan) 74 | if !ok { 75 | return nil, fmt.Errorf("newSrcPlan is not a plan.BoundPlan") 76 | } 77 | return &BoundPlan{ 78 | Plan: NewMaterializePlan(p.tx, np.Plan), 79 | }, nil 80 | } 81 | 82 | func (p MaterializePlan) Open() (query.Scan, error) { 83 | sche := p.srcPlan.Schema() 84 | temp := NewTempTable(p.tx, *sche) 85 | src, err := p.srcPlan.Open() 86 | if err != nil { 87 | return nil, fmt.Errorf("srcPlan.Open error: %w", err) 88 | } 89 | dest, err := temp.Open() 90 | if err != nil { 91 | return nil, fmt.Errorf("temp.Open error: %w", err) 92 | } 93 | for { 94 | ok, err := src.Next() 95 | if err != nil { 96 | return nil, fmt.Errorf("src.Next error: %w", err) 97 | } 98 | if !ok {
99 | break 100 | } 101 | if err := dest.Insert(); err != nil { 102 | return nil, fmt.Errorf("us.Insert error: %w", err) 103 | } 104 | for _, fldName := range sche.FieldNames() { 105 | val, err := src.Val(fldName) 106 | if err != nil { 107 | return nil, fmt.Errorf("src.Val error: %w", err) 108 | } 109 | if err := dest.SetVal(fldName, val); err != nil { 110 | return nil, fmt.Errorf("us.SetVal error: %w", err) 111 | } 112 | } 113 | } 114 | if err := src.Close(); err != nil { 115 | return nil, fmt.Errorf("src.Close error: %w", err) 116 | } 117 | if err := dest.BeforeFirst(); err != nil { 118 | return nil, fmt.Errorf("us.BeforeFirst error: %w", err) 119 | } 120 | return dest, nil 121 | } 122 | 123 | func (p MaterializePlan) BlockAccessed() int { 124 | l := record.NewLayoutSchema(*p.srcPlan.Schema()) 125 | rpb := float64(p.tx.BlockSize()) / float64(l.SlotSize()) 126 | return int(math.Ceil(float64(p.srcPlan.RecordsOutput()) / rpb)) 127 | } 128 | 129 | func (p MaterializePlan) RecordsOutput() int { 130 | return p.srcPlan.RecordsOutput() 131 | } 132 | 133 | func (p MaterializePlan) DistinctValues(fieldName schema.FieldName) int { 134 | return p.srcPlan.DistinctValues(fieldName) 135 | } 136 | 137 | func (p MaterializePlan) Schema() *schema.Schema { 138 | return p.srcPlan.Schema() 139 | } 140 | -------------------------------------------------------------------------------- /internal/plan/mergejoin.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/abekoh/simple-db/internal/query" 7 | "github.com/abekoh/simple-db/internal/record/schema" 8 | "github.com/abekoh/simple-db/internal/statement" 9 | "github.com/abekoh/simple-db/internal/transaction" 10 | ) 11 | 12 | type MergeJoinScan struct { 13 | s1 query.Scan 14 | s2 SortScan 15 | fieldName1, fieldName2 schema.FieldName 16 | joinValue schema.Constant 17 | } 18 | 19 | var _ query.Scan = (*MergeJoinScan)(nil) 20 | 21 | func 
NewMergeJoinScan(s1 query.Scan, s2 SortScan, fieldName1, fieldName2 schema.FieldName) (*MergeJoinScan, error) { 22 | ms := MergeJoinScan{s1: s1, s2: s2, fieldName1: fieldName1, fieldName2: fieldName2} 23 | if err := ms.BeforeFirst(); err != nil { 24 | return nil, fmt.Errorf("ms.BeforeFirst error: %w", err) 25 | } 26 | return &ms, nil 27 | } 28 | 29 | func (m *MergeJoinScan) Val(fieldName schema.FieldName) (schema.Constant, error) { 30 | if m.s1.HasField(fieldName) { 31 | v, err := m.s1.Val(fieldName) 32 | if err != nil { 33 | return nil, fmt.Errorf("s1.Val error: %w", err) 34 | } 35 | return v, nil 36 | } else { 37 | v, err := m.s2.Val(fieldName) 38 | if err != nil { 39 | return nil, fmt.Errorf("s2.Val error: %w", err) 40 | } 41 | return v, nil 42 | } 43 | } 44 | 45 | func (m *MergeJoinScan) BeforeFirst() error { 46 | if err := m.s1.BeforeFirst(); err != nil { 47 | return fmt.Errorf("s1.BeforeFirst error: %w", err) 48 | } 49 | if err := m.s2.BeforeFirst(); err != nil { 50 | return fmt.Errorf("s2.BeforeFirst error: %w", err) 51 | } 52 | return nil 53 | } 54 | 55 | func (m *MergeJoinScan) Next() (bool, error) { 56 | ok2, err := m.s2.Next() 57 | if err != nil { 58 | return false, fmt.Errorf("s2.Next error: %w", err) 59 | } 60 | if ok2 { 61 | v2, err := m.s2.Val(m.fieldName2) 62 | if err != nil { 63 | return false, fmt.Errorf("s2.Val error: %w", err) 64 | } 65 | if m.joinValue != nil && v2.Equals(m.joinValue) { 66 | return true, nil 67 | } 68 | } 69 | ok1, err := m.s1.Next() 70 | if err != nil { 71 | return false, fmt.Errorf("s1.Next error: %w", err) 72 | } 73 | if ok1 { 74 | v1, err := m.s1.Val(m.fieldName1) 75 | if err != nil { 76 | return false, fmt.Errorf("s1.Val error: %w", err) 77 | } 78 | if m.joinValue != nil && v1.Equals(m.joinValue) { 79 | if err := m.s2.RestorePosition(); err != nil { 80 | return false, fmt.Errorf("s2.RestorePosition error: %w", err) 81 | } 82 | return true, nil 83 | } 84 | } 85 | for ok1 && ok2 { 86 | v1, err := m.s1.Val(m.fieldName1) 87 | 
if err != nil { 88 | return false, fmt.Errorf("s1.Val error: %w", err) 89 | } 90 | v2, err := m.s2.Val(m.fieldName2) 91 | if err != nil { 92 | return false, fmt.Errorf("s2.Val error: %w", err) 93 | } 94 | if v1.Compare(v2) < 0 { 95 | ok1, err = m.s1.Next() 96 | if err != nil { 97 | return false, fmt.Errorf("s1.Next error: %w", err) 98 | } 99 | } else if v1.Compare(v2) > 0 { 100 | ok2, err = m.s2.Next() 101 | if err != nil { 102 | return false, fmt.Errorf("s2.Next error: %w", err) 103 | } 104 | } else { 105 | m.s2.SavePosition() 106 | jv, err := m.s2.Val(m.fieldName2) 107 | if err != nil { 108 | return false, fmt.Errorf("s2.Val error: %w", err) 109 | } 110 | m.joinValue = jv 111 | return true, nil 112 | } 113 | } 114 | return false, nil 115 | } 116 | 117 | func (m *MergeJoinScan) Int32(fieldName schema.FieldName) (int32, error) { 118 | if m.s1.HasField(fieldName) { 119 | v, err := m.s1.Int32(fieldName) 120 | if err != nil { 121 | return 0, fmt.Errorf("s1.Int32 error: %w", err) 122 | } 123 | return v, nil 124 | } else { 125 | v, err := m.s2.Int32(fieldName) 126 | if err != nil { 127 | return 0, fmt.Errorf("s2.Int32 error: %w", err) 128 | } 129 | return v, nil 130 | } 131 | } 132 | 133 | func (m *MergeJoinScan) Str(fieldName schema.FieldName) (string, error) { 134 | if m.s1.HasField(fieldName) { 135 | v, err := m.s1.Str(fieldName) 136 | if err != nil { 137 | return "", fmt.Errorf("s1.Str error: %w", err) 138 | } 139 | return v, nil 140 | } else { 141 | v, err := m.s2.Str(fieldName) 142 | if err != nil { 143 | return "", fmt.Errorf("s2.Str error: %w", err) 144 | } 145 | return v, nil 146 | } 147 | } 148 | 149 | func (m *MergeJoinScan) HasField(fieldName schema.FieldName) bool { 150 | return m.s1.HasField(fieldName) || m.s2.HasField(fieldName) 151 | } 152 | 153 | func (m *MergeJoinScan) Close() error { 154 | if err := m.s1.Close(); err != nil { 155 | return fmt.Errorf("s1.Close error: %w", err) 156 | } 157 | if err := m.s2.Close(); err != nil { 158 | return 
fmt.Errorf("s2.Close error: %w", err) 159 | } 160 | return nil 161 | } 162 | 163 | type MergeJoinPlan struct { 164 | tx *transaction.Transaction 165 | p1, p2 Plan 166 | fieldName1, fieldName2 schema.FieldName 167 | sche schema.Schema 168 | } 169 | 170 | var _ Plan = (*MergeJoinPlan)(nil) 171 | 172 | func NewMergeJoinPlan(tx *transaction.Transaction, p1, p2 Plan, fieldName1, fieldName2 schema.FieldName) (*MergeJoinPlan, error) { 173 | sp1 := NewSortPlan(tx, p1, query.Order{query.OrderElement{Field: fieldName1, OrderType: query.Asc}}) 174 | sp2 := NewSortPlan(tx, p2, query.Order{query.OrderElement{Field: fieldName2, OrderType: query.Asc}}) 175 | sche := schema.NewSchema() 176 | sche.AddAll(*p1.Schema()) 177 | sche.AddAll(*p2.Schema()) 178 | return &MergeJoinPlan{tx: tx, p1: sp1, p2: sp2, fieldName1: fieldName1, fieldName2: fieldName2, sche: sche}, nil 179 | } 180 | 181 | func (m MergeJoinPlan) Result() {} 182 | 183 | func (m MergeJoinPlan) Info() Info { 184 | return Info{ 185 | NodeType: "MergeJoin", 186 | BlockAccessed: m.BlockAccessed(), 187 | RecordsOutput: m.RecordsOutput(), 188 | Children: []Info{m.p1.Info(), m.p2.Info()}, 189 | } 190 | } 191 | 192 | func (m MergeJoinPlan) Placeholders(findSchema func(tableName string) (*schema.Schema, error)) map[int]schema.FieldType { 193 | placeholders := m.p1.Placeholders(findSchema) 194 | for k, v := range m.p2.Placeholders(findSchema) { 195 | placeholders[k] = v 196 | } 197 | return placeholders 198 | } 199 | 200 | func (m MergeJoinPlan) SwapParams(params map[int]schema.Constant) (statement.Bound, error) { 201 | newP1, err := m.p1.SwapParams(params) 202 | if err != nil { 203 | return nil, fmt.Errorf("p1.SwapParams error: %w", err) 204 | } 205 | newBP1, ok := newP1.(*BoundPlan) 206 | if !ok { 207 | return nil, fmt.Errorf("newP1 is not a plan.BoundPlan") 208 | } 209 | newP2, err := m.p2.SwapParams(params) 210 | if err != nil { 211 | return nil, fmt.Errorf("p2.SwapParams error: %w", err) 212 | } 213 | newBP2, ok := 
newP2.(*BoundPlan) 214 | if !ok { 215 | return nil, fmt.Errorf("newP2 is not a plan.BoundPlan") 216 | } 217 | newMergeJoinPlan, err := NewMergeJoinPlan(m.tx, newBP1, newBP2, m.fieldName1, m.fieldName2) 218 | if err != nil { 219 | return nil, fmt.Errorf("NewMergeJoinPlan error: %w", err) 220 | } 221 | return &BoundPlan{ 222 | Plan: newMergeJoinPlan, 223 | }, nil 224 | } 225 | 226 | func (m MergeJoinPlan) Open() (query.Scan, error) { 227 | s1, err := m.p1.Open() 228 | if err != nil { 229 | return nil, fmt.Errorf("p1.Open error: %w", err) 230 | } 231 | s2, err := m.p2.Open() 232 | if err != nil { 233 | return nil, fmt.Errorf("p2.Open error: %w", err) 234 | } 235 | sortScane, ok := s2.(*SortScan) 236 | if !ok { 237 | return nil, fmt.Errorf("s2 is not a SortScan") 238 | } 239 | ms, err := NewMergeJoinScan(s1, *sortScane, m.fieldName1, m.fieldName2) 240 | if err != nil { 241 | return nil, fmt.Errorf("NewMergeJoinScan error: %w", err) 242 | } 243 | return ms, nil 244 | } 245 | 246 | func (m MergeJoinPlan) BlockAccessed() int { 247 | return m.p1.BlockAccessed() + m.p2.BlockAccessed() 248 | } 249 | 250 | func (m MergeJoinPlan) RecordsOutput() int { 251 | return (m.p1.RecordsOutput() * m.p2.RecordsOutput()) / max(m.p1.DistinctValues(m.fieldName1), m.p2.DistinctValues(m.fieldName2)) 252 | } 253 | 254 | func (m MergeJoinPlan) DistinctValues(fieldName schema.FieldName) int { 255 | if m.p1.Schema().HasField(fieldName) { 256 | return m.p1.DistinctValues(fieldName) 257 | } 258 | return m.p2.DistinctValues(fieldName) 259 | } 260 | 261 | func (m MergeJoinPlan) Schema() *schema.Schema { 262 | return &m.sche 263 | } 264 | -------------------------------------------------------------------------------- /internal/plan/mergejoin_test.go: -------------------------------------------------------------------------------- 1 | package plan_test 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/abekoh/simple-db/internal/plan" 9 |
"github.com/abekoh/simple-db/internal/query" 10 | "github.com/abekoh/simple-db/internal/record" 11 | "github.com/abekoh/simple-db/internal/record/schema" 12 | "github.com/abekoh/simple-db/internal/simpledb" 13 | "github.com/abekoh/simple-db/internal/transaction" 14 | ) 15 | 16 | func TestMergeJoinPlan(t *testing.T) { 17 | transaction.CleanupLockTable(t) 18 | ctx := context.Background() 19 | db, err := simpledb.New(ctx, t.TempDir()) 20 | if err != nil { 21 | t.Fatal(err) 22 | } 23 | tx, err := db.NewTx(ctx) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | 28 | sche1 := schema.NewSchema() 29 | sche1.AddInt32Field("department_id") 30 | sche1.AddStrField("department_name", 10) 31 | if err := db.MetadataMgr().CreateTable("departments", sche1, tx); err != nil { 32 | t.Fatal(err) 33 | } 34 | layout1 := record.NewLayoutSchema(sche1) 35 | updateScan1, err := record.NewTableScan(tx, "departments", layout1) 36 | if err != nil { 37 | t.Fatal(err) 38 | } 39 | if err := updateScan1.BeforeFirst(); err != nil { 40 | t.Fatal(err) 41 | } 42 | for _, v := range []struct { 43 | departmentID int32 44 | departmentName string 45 | }{ 46 | {10, "compsci"}, 47 | {18, "basketry"}, 48 | {20, "math"}, 49 | {30, "drama"}, 50 | } { 51 | if err := updateScan1.Insert(); err != nil { 52 | t.Fatal(err) 53 | } 54 | if err := updateScan1.SetInt32("department_id", v.departmentID); err != nil { 55 | t.Fatal(err) 56 | } 57 | if err := updateScan1.SetStr("department_name", v.departmentName); err != nil { 58 | t.Fatal(err) 59 | } 60 | } 61 | if err := updateScan1.Close(); err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | sche2 := schema.NewSchema() 66 | sche2.AddInt32Field("student_id") 67 | sche2.AddStrField("student_name", 10) 68 | sche2.AddInt32Field("department_id") 69 | if err := db.MetadataMgr().CreateTable("students", sche2, tx); err != nil { 70 | t.Fatal(err) 71 | } 72 | layout2 := record.NewLayoutSchema(sche2) 73 | updateScan2, err := record.NewTableScan(tx, "students", layout2) 74 | if err != 
nil { 75 | t.Fatal(err) 76 | } 77 | if err := updateScan2.BeforeFirst(); err != nil { 78 | t.Fatal(err) 79 | } 80 | for _, v := range []struct { 81 | studentID int32 82 | studentName string 83 | departmentID int32 84 | }{ 85 | {4, "sue", 20}, 86 | {1, "joe", 10}, 87 | {5, "bob", 30}, 88 | {2, "amy", 20}, 89 | {6, "kim", 20}, 90 | {3, "max", 10}, 91 | {8, "pat", 20}, 92 | {7, "art", 30}, 93 | {9, "lee", 10}, 94 | } { 95 | if err := updateScan2.Insert(); err != nil { 96 | t.Fatal(err) 97 | } 98 | if err := updateScan2.SetInt32("student_id", v.studentID); err != nil { 99 | t.Fatal(err) 100 | } 101 | if err := updateScan2.SetStr("student_name", v.studentName); err != nil { 102 | t.Fatal(err) 103 | } 104 | if err := updateScan2.SetInt32("department_id", v.departmentID); err != nil { 105 | t.Fatal(err) 106 | } 107 | } 108 | if err := updateScan2.Close(); err != nil { 109 | t.Fatal(err) 110 | } 111 | 112 | tablePlan1, err := plan.NewTablePlan("departments", tx, db.MetadataMgr()) 113 | if err != nil { 114 | t.Fatal(err) 115 | } 116 | tablePlan2, err := plan.NewTablePlan("students", tx, db.MetadataMgr()) 117 | if err != nil { 118 | t.Fatal(err) 119 | } 120 | joinPlan, err := plan.NewMergeJoinPlan(tx, tablePlan1, tablePlan2, "department_id", "department_id") 121 | if err != nil { 122 | t.Fatal(err) 123 | } 124 | sortPlan := plan.NewSortPlan(tx, joinPlan, query.Order{query.OrderElement{Field: "student_name", OrderType: query.Asc}}) 125 | projectPlan := plan.NewProjectPlan(sortPlan, []schema.FieldName{"student_name", "department_name"}) 126 | 127 | queryScan, err := projectPlan.Open() 128 | if err != nil { 129 | t.Fatal(err) 130 | } 131 | if err := queryScan.BeforeFirst(); err != nil { 132 | t.Fatal(err) 133 | } 134 | res := make([]string, 0, 9) 135 | for { 136 | ok, err := queryScan.Next() 137 | if err != nil { 138 | t.Fatal(err) 139 | } 140 | if !ok { 141 | break 142 | } 143 | studentName, err := queryScan.Str("student_name") 144 | if err != nil { 145 | t.Fatal(err) 146 | } 
147 | departmentName, err := queryScan.Str("department_name") 148 | if err != nil { 149 | t.Fatal(err) 150 | } 151 | res = append(res, studentName+":"+departmentName) 152 | } 153 | if err := queryScan.Close(); err != nil { 154 | t.Fatal(err) 155 | } 156 | if len(res) != 9 { 157 | t.Fatalf("got %d, want 9", len(res)) 158 | } 159 | if !reflect.DeepEqual(res, []string{ 160 | "amy:math", 161 | "art:drama", 162 | "bob:drama", 163 | "joe:compsci", 164 | "kim:math", 165 | "lee:compsci", 166 | "max:compsci", 167 | "pat:math", 168 | "sue:math", 169 | }) { 170 | t.Fatalf("unexpected result: %v", res) 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /internal/plan/multibuffer_test.go: -------------------------------------------------------------------------------- 1 | package plan_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/abekoh/simple-db/internal/plan" 10 | "github.com/abekoh/simple-db/internal/query" 11 | "github.com/abekoh/simple-db/internal/record" 12 | "github.com/abekoh/simple-db/internal/record/schema" 13 | "github.com/abekoh/simple-db/internal/simpledb" 14 | "github.com/abekoh/simple-db/internal/transaction" 15 | ) 16 | 17 | func TestProductPlan(t *testing.T) { 18 | transaction.CleanupLockTable(t) 19 | ctx := context.Background() 20 | db, err := simpledb.New(ctx, t.TempDir()) 21 | if err != nil { 22 | t.Fatal(err) 23 | } 24 | tx, err := db.NewTx(ctx) 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | 29 | sche1 := schema.NewSchema() 30 | sche1.AddInt32Field("A") 31 | sche1.AddStrField("B", 9) 32 | if err := db.MetadataMgr().CreateTable("T1", sche1, tx); err != nil { 33 | t.Fatal(err) 34 | } 35 | layout1 := record.NewLayoutSchema(sche1) 36 | us1, err := record.NewTableScan(tx, "T1", layout1) 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | if err := us1.BeforeFirst(); err != nil { 41 | t.Fatal(err) 42 | } 43 | n := 5 44 | for i := 0; i < n; i++ { 45 | if err := 
us1.Insert(); err != nil { 46 | t.Fatal(err) 47 | } 48 | if err := us1.SetInt32("A", int32(i)); err != nil { 49 | t.Fatal(err) 50 | } 51 | if err := us1.SetStr("B", fmt.Sprintf("bbb%d", i)); err != nil { 52 | t.Fatal(err) 53 | } 54 | } 55 | if err := us1.Close(); err != nil { 56 | t.Fatal(err) 57 | } 58 | 59 | sche2 := schema.NewSchema() 60 | sche2.AddInt32Field("C") 61 | sche2.AddStrField("D", 9) 62 | if err := db.MetadataMgr().CreateTable("T2", sche2, tx); err != nil { 63 | t.Fatal(err) 64 | } 65 | layout2 := record.NewLayoutSchema(sche2) 66 | us2, err := record.NewTableScan(tx, "T2", layout2) 67 | if err != nil { 68 | t.Fatal(err) 69 | } 70 | if err := us2.BeforeFirst(); err != nil { 71 | t.Fatal(err) 72 | } 73 | for i := 0; i < n; i++ { 74 | if err := us2.Insert(); err != nil { 75 | t.Fatal(err) 76 | } 77 | if err := us2.SetInt32("C", int32(i)); err != nil { 78 | t.Fatal(err) 79 | } 80 | if err := us2.SetStr("D", fmt.Sprintf("ddd%d", i)); err != nil { 81 | t.Fatal(err) 82 | } 83 | } 84 | if err := us2.Close(); err != nil { 85 | t.Fatal(err) 86 | } 87 | 88 | tablePlan1, err := plan.NewTablePlan("T1", tx, db.MetadataMgr()) 89 | if err != nil { 90 | t.Fatal(err) 91 | } 92 | tablePlan2, err := plan.NewTablePlan("T2", tx, db.MetadataMgr()) 93 | if err != nil { 94 | t.Fatal(err) 95 | } 96 | prodPlan := plan.NewMultiBufferProductPlan(tx, tablePlan1, tablePlan2) 97 | selectPlan := plan.NewSelectPlan(prodPlan, query.NewPredicate(query.NewTerm(schema.FieldName("A"), schema.FieldName("C")))) 98 | projectPlan := plan.NewProjectPlan(selectPlan, []schema.FieldName{"B", "D"}) 99 | 100 | s, err := projectPlan.Open() 101 | if err != nil { 102 | t.Fatal(err) 103 | } 104 | 105 | got := make([]string, 0, n*n) 106 | for { 107 | ok, err := s.Next() 108 | if err != nil { 109 | t.Fatal(err) 110 | } 111 | if !ok { 112 | break 113 | } 114 | b, err := s.Str("B") 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | d, err := s.Str("D") 119 | if err != nil { 120 | t.Fatal(err) 121 | } 
122 | got = append(got, fmt.Sprintf("%s, %s", b, d)) 123 | } 124 | if len(got) != n { 125 | t.Errorf("got %d, want %d", len(got), n) 126 | } 127 | expected := `bbb0, ddd0 128 | bbb1, ddd1 129 | bbb2, ddd2 130 | bbb3, ddd3 131 | bbb4, ddd4` 132 | if strings.Join(got, "\n") != expected { 133 | t.Errorf("got %s, want %s", strings.Join(got, "\n"), expected) 134 | } 135 | if err := s.Close(); err != nil { 136 | t.Fatal(err) 137 | } 138 | if err := tx.Commit(); err != nil { 139 | t.Fatal(err) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /internal/plan/planner_heuristic_test.go: -------------------------------------------------------------------------------- 1 | package plan_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/abekoh/simple-db/internal/plan" 8 | "github.com/abekoh/simple-db/internal/simpledb" 9 | "github.com/abekoh/simple-db/internal/testdata" 10 | "github.com/abekoh/simple-db/internal/transaction" 11 | "github.com/google/go-cmp/cmp" 12 | ) 13 | 14 | func TestHeuristicQueryPlanner_QueryPlans(t *testing.T) { 15 | type test struct { 16 | name string 17 | snapshot string 18 | query string 19 | planInfo plan.Info 20 | } 21 | for _, tt := range []test{ 22 | { 23 | name: "one table, no index", 24 | snapshot: "tables_data", 25 | query: "SELECT name FROM students WHERE student_id = 200588", 26 | planInfo: plan.Info{ 27 | NodeType: "Project", 28 | Conditions: map[string][]string{"fields": {"name"}}, 29 | BlockAccessed: 770, 30 | RecordsOutput: 2, 31 | Children: []plan.Info{ 32 | { 33 | NodeType: "Select", 34 | Conditions: map[string][]string{"predicate": {"student_id=200588"}}, 35 | BlockAccessed: 770, 36 | RecordsOutput: 2, 37 | Children: []plan.Info{ 38 | { 39 | NodeType: "Table", 40 | Conditions: map[string][]string{"table": {"students"}}, 41 | BlockAccessed: 770, 42 | RecordsOutput: 10000, 43 | }, 44 | }, 45 | }, 46 | }, 47 | }, 48 | }, 49 | { 50 | name: "one table, use index", 51 | 
snapshot: "tables_indexes_data", 52 | query: "SELECT name FROM students WHERE student_id = 200588", 53 | planInfo: plan.Info{ 54 | NodeType: "Project", 55 | Conditions: map[string][]string{"fields": {"name"}}, 56 | BlockAccessed: 2, 57 | RecordsOutput: 2, 58 | Children: []plan.Info{ 59 | { 60 | NodeType: "Select", 61 | Conditions: map[string][]string{"predicate": {"student_id=200588"}}, 62 | BlockAccessed: 2, 63 | RecordsOutput: 2, 64 | Children: []plan.Info{ 65 | { 66 | NodeType: "IndexSelect", 67 | Conditions: map[string][]string{"index": {"students_pkey"}, "value": {"200588"}}, 68 | BlockAccessed: 2, 69 | RecordsOutput: 2, 70 | Children: []plan.Info{ 71 | { 72 | NodeType: "Table", 73 | Conditions: map[string][]string{"table": {"students"}}, 74 | BlockAccessed: 770, 75 | RecordsOutput: 10000, 76 | }, 77 | }, 78 | }, 79 | }, 80 | }, 81 | }, 82 | }, 83 | }, 84 | { 85 | name: "join two tables, no index", 86 | snapshot: "tables_data", 87 | query: "SELECT name, department_name FROM students JOIN departments ON major_id = department_id WHERE student_id = 200588", 88 | planInfo: plan.Info{ 89 | NodeType: "Project", 90 | Conditions: map[string][]string{"fields": {"name", "department_name"}}, 91 | BlockAccessed: 776, 92 | RecordsOutput: 0, 93 | Children: []plan.Info{ 94 | { 95 | NodeType: "Select", 96 | Conditions: map[string][]string{"predicate": {"major_id=department_id"}}, 97 | BlockAccessed: 776, 98 | RecordsOutput: 0, 99 | Children: []plan.Info{ 100 | { 101 | NodeType: "MultiBufferProduct", 102 | BlockAccessed: 776, 103 | RecordsOutput: 200, 104 | Children: []plan.Info{ 105 | { 106 | NodeType: "Table", 107 | Conditions: map[string][]string{"table": {"departments"}}, 108 | BlockAccessed: 6, 109 | RecordsOutput: 100, 110 | }, 111 | { 112 | NodeType: "Select", 113 | Conditions: map[string][]string{"predicate": {"student_id=200588"}}, 114 | BlockAccessed: 770, 115 | RecordsOutput: 2, 116 | Children: []plan.Info{ 117 | { 118 | NodeType: "Table", 119 | Conditions: 
map[string][]string{"table": {"students"}}, 120 | BlockAccessed: 770, 121 | RecordsOutput: 10000, 122 | }, 123 | }, 124 | }, 125 | }, 126 | }, 127 | }, 128 | }, 129 | }, 130 | }, 131 | }, 132 | { 133 | name: "join two tables, use index", 134 | snapshot: "tables_indexes_data", 135 | query: "SELECT name, department_name FROM students JOIN departments ON major_id = department_id WHERE student_id = 200588", 136 | planInfo: plan.Info{ 137 | NodeType: "Project", 138 | Conditions: map[string][]string{"fields": {"name", "department_name"}}, 139 | BlockAccessed: 8, 140 | RecordsOutput: 0, 141 | Children: []plan.Info{ 142 | { 143 | NodeType: "Select", 144 | Conditions: map[string][]string{"predicate": {"major_id=department_id"}}, 145 | BlockAccessed: 8, 146 | RecordsOutput: 0, 147 | Children: []plan.Info{ 148 | { 149 | NodeType: "IndexJoin", 150 | Conditions: map[string][]string{"index": {"departments_pkey"}, "field": {"major_id"}}, 151 | BlockAccessed: 8, 152 | RecordsOutput: 4, 153 | Children: []plan.Info{ 154 | { 155 | NodeType: "Select", 156 | Conditions: map[string][]string{"predicate": {"student_id=200588"}}, 157 | BlockAccessed: 2, 158 | RecordsOutput: 2, 159 | Children: []plan.Info{ 160 | { 161 | NodeType: "IndexSelect", 162 | Conditions: map[string][]string{"index": {"students_pkey"}, "value": {"200588"}}, 163 | BlockAccessed: 2, 164 | RecordsOutput: 2, 165 | Children: []plan.Info{ 166 | { 167 | NodeType: "Table", 168 | Conditions: map[string][]string{"table": {"students"}}, 169 | BlockAccessed: 770, 170 | RecordsOutput: 10000, 171 | }, 172 | }, 173 | }, 174 | }, 175 | }, 176 | { 177 | NodeType: "Table", 178 | Conditions: map[string][]string{"table": {"departments"}}, 179 | BlockAccessed: 6, 180 | RecordsOutput: 100, 181 | }, 182 | }, 183 | }, 184 | }, 185 | }, 186 | }, 187 | }, 188 | }, 189 | { 190 | name: "sort", 191 | snapshot: "tables_data", 192 | query: "SELECT name FROM students ORDER BY name", 193 | planInfo: plan.Info{ 194 | NodeType: "Project", 195 | 
Conditions: map[string][]string{"fields": {"name"}}, 196 | BlockAccessed: 750, 197 | RecordsOutput: 10000, 198 | Children: []plan.Info{ 199 | { 200 | NodeType: "Sort", 201 | Conditions: map[string][]string{"sortFields": {"name"}}, 202 | BlockAccessed: 750, 203 | RecordsOutput: 10000, 204 | Children: []plan.Info{ 205 | { 206 | NodeType: "Table", 207 | Conditions: map[string][]string{"table": {"students"}}, 208 | BlockAccessed: 770, 209 | RecordsOutput: 10000, 210 | }, 211 | }, 212 | }, 213 | }, 214 | }, 215 | }, 216 | { 217 | name: "group by (count)", 218 | snapshot: "tables_data", 219 | query: "SELECT grad_year, COUNT(*) AS cnt FROM students GROUP BY grad_year", 220 | planInfo: plan.Info{ 221 | NodeType: "Project", 222 | Conditions: map[string][]string{"fields": {"grad_year", "cnt"}}, 223 | BlockAccessed: 750, 224 | RecordsOutput: 3334, 225 | Children: []plan.Info{ 226 | { 227 | NodeType: "GroupBy", 228 | Conditions: map[string][]string{"aggregationFuncs": {"COUNT(*) AS cnt"}, "groupFields": {"grad_year"}}, 229 | BlockAccessed: 750, 230 | RecordsOutput: 3334, 231 | Children: []plan.Info{ 232 | { 233 | NodeType: "Sort", 234 | Conditions: map[string][]string{"sortFields": {"grad_year"}}, 235 | BlockAccessed: 750, 236 | RecordsOutput: 10000, 237 | Children: []plan.Info{ 238 | { 239 | NodeType: "Table", 240 | Conditions: map[string][]string{"table": {"students"}}, 241 | BlockAccessed: 770, 242 | RecordsOutput: 10000, 243 | }, 244 | }, 245 | }, 246 | }, 247 | }, 248 | }, 249 | }, 250 | }, 251 | { 252 | name: "group by (max)", 253 | snapshot: "tables_data", 254 | query: "SELECT major_id, MAX(grad_year) AS max_grad_year FROM students GROUP BY major_id", 255 | planInfo: plan.Info{ 256 | NodeType: "Project", 257 | Conditions: map[string][]string{"fields": {"major_id", "max_grad_year"}}, 258 | BlockAccessed: 750, 259 | RecordsOutput: 3334, 260 | Children: []plan.Info{ 261 | { 262 | NodeType: "GroupBy", 263 | Conditions: map[string][]string{"aggregationFuncs": 
{"MAX(grad_year) AS max_grad_year"}, "groupFields": {"major_id"}}, 264 | BlockAccessed: 750, 265 | RecordsOutput: 3334, 266 | Children: []plan.Info{ 267 | { 268 | NodeType: "Sort", 269 | Conditions: map[string][]string{"sortFields": {"major_id"}}, 270 | BlockAccessed: 750, 271 | RecordsOutput: 10000, 272 | Children: []plan.Info{ 273 | { 274 | NodeType: "Table", 275 | Conditions: map[string][]string{"table": {"students"}}, 276 | BlockAccessed: 770, 277 | RecordsOutput: 10000, 278 | }, 279 | }, 280 | }, 281 | }, 282 | }, 283 | }, 284 | }, 285 | }, 286 | } { 287 | t.Run(tt.name, func(t *testing.T) { 288 | transaction.CleanupLockTable(t) 289 | ctx := context.Background() 290 | dir := t.TempDir() 291 | if err := testdata.CopySnapshotData(tt.snapshot, dir); err != nil { 292 | t.Fatal(err) 293 | } 294 | db, err := simpledb.New(ctx, dir) 295 | if err != nil { 296 | t.Fatal(err) 297 | } 298 | tx, err := db.NewTx(ctx) 299 | if err != nil { 300 | t.Fatal(err) 301 | } 302 | res, err := db.Planner().Execute(tt.query, tx) 303 | if err != nil { 304 | t.Fatal(err) 305 | } 306 | p, ok := res.(plan.Plan) 307 | if !ok { 308 | t.Fatalf("unexpected type %T", res) 309 | } 310 | if diff := cmp.Diff(tt.planInfo, p.Info()); diff != "" { 311 | t.Errorf("(-got, +expected)\n%s", diff) 312 | } 313 | }) 314 | } 315 | } 316 | -------------------------------------------------------------------------------- /internal/plan/sort.go: -------------------------------------------------------------------------------- 1 | package plan 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/abekoh/simple-db/internal/query" 7 | "github.com/abekoh/simple-db/internal/record/schema" 8 | "github.com/abekoh/simple-db/internal/statement" 9 | "github.com/abekoh/simple-db/internal/transaction" 10 | ) 11 | 12 | type SortPlan struct { 13 | tx *transaction.Transaction 14 | p Plan 15 | sche schema.Schema 16 | sortFields []schema.FieldName 17 | order query.Order 18 | } 19 | 20 | var _ Plan = (*SortPlan)(nil) 21 | 22 | func 
NewSortPlan(tx *transaction.Transaction, p Plan, order query.Order) *SortPlan { 23 | sortFields := make([]schema.FieldName, 0) 24 | for _, el := range order { 25 | sortFields = append(sortFields, el.Field) 26 | } 27 | return &SortPlan{tx: tx, p: p, sche: *p.Schema(), sortFields: sortFields, order: order} 28 | } 29 | 30 | func (s SortPlan) Result() {} 31 | 32 | func (s SortPlan) Info() Info { 33 | conditions := make(map[string][]string) 34 | for _, fld := range s.sortFields { 35 | conditions["sortFields"] = append(conditions["sortFields"], string(fld)) 36 | } 37 | return Info{ 38 | NodeType: "Sort", 39 | Conditions: conditions, 40 | BlockAccessed: s.BlockAccessed(), 41 | RecordsOutput: s.RecordsOutput(), 42 | Children: []Info{s.p.Info()}, 43 | } 44 | } 45 | 46 | func (s SortPlan) Placeholders(findSchema func(tableName string) (*schema.Schema, error)) map[int]schema.FieldType { 47 | return s.p.Placeholders(findSchema) 48 | } 49 | 50 | func (s SortPlan) SwapParams(params map[int]schema.Constant) (statement.Bound, error) { 51 | bound, err := s.p.SwapParams(params) 52 | if err != nil { 53 | return nil, err 54 | } 55 | bp, ok := bound.(*BoundPlan) 56 | if !ok { 57 | return nil, fmt.Errorf("bound.(*plan.BoundPlan) error") 58 | } 59 | return &BoundPlan{ 60 | Plan: NewSortPlan(s.tx, bp, s.order), 61 | }, nil 62 | } 63 | 64 | func (s SortPlan) Open() (query.Scan, error) { 65 | src, err := s.p.Open() 66 | if err != nil { 67 | return nil, fmt.Errorf("p.Open error: %w", err) 68 | } 69 | 70 | // split the source into sorted runs 71 | var runs []TempTable 72 | if err := src.BeforeFirst(); err != nil { 73 | return nil, fmt.Errorf("src.BeforeFirst error: %w", err) 74 | } 75 | ok, err := src.Next() 76 | if err != nil { 77 | return nil, fmt.Errorf("src.Next error: %w", err) 78 | } 79 | if ok { 80 | currentTemp := NewTempTable(s.tx, s.sche) 81 | runs = append(runs, *currentTemp) 82 | currentScan, err := currentTemp.Open() 83 | if err != nil { 84 | return nil, 
fmt.Errorf("currentTemp.Open error: %w", err) 85 | } 86 | for { 87 | ok, err := s.copy(src, currentScan) 88 | if err != nil { 89 | return nil, fmt.Errorf("copy error: %w", err) 90 | } 91 | if !ok { 92 | break 93 | } 94 | cmpRes, err := s.order.Compare(src, currentScan) 95 | if err != nil { 96 | return nil, fmt.Errorf("NewOrder.Compare error: %w", err) 97 | } 98 | if cmpRes < 0 { 99 | if err := currentScan.Close(); err != nil { 100 | return nil, fmt.Errorf("currentScan.Close error: %w", err) 101 | } 102 | currentTemp = NewTempTable(s.tx, s.sche) 103 | runs = append(runs, *currentTemp) 104 | currentScan, err = currentTemp.Open() 105 | if err != nil { 106 | return nil, fmt.Errorf("currentTemp.Open error: %w", err) 107 | } 108 | } 109 | } 110 | if err := currentScan.Close(); err != nil { 111 | return nil, fmt.Errorf("currentScan.Close error: %w", err) 112 | } 113 | } 114 | if err := src.Close(); err != nil { 115 | return nil, fmt.Errorf("src.Close error: %w", err) 116 | } 117 | for { 118 | if len(runs) <= 2 { 119 | break 120 | } 121 | // do a merge iteration 122 | newRuns := make([]TempTable, 0) 123 | for len(runs) > 1 { 124 | p1 := runs[0] 125 | p2 := runs[1] 126 | runs = runs[2:] 127 | // merge two runs 128 | src1, err := p1.Open() 129 | if err != nil { 130 | return nil, fmt.Errorf("p1.Open error: %w", err) 131 | } 132 | src2, err := p2.Open() 133 | if err != nil { 134 | return nil, fmt.Errorf("p2.Open error: %w", err) 135 | } 136 | res := NewTempTable(s.tx, s.sche) 137 | dest, err := res.Open() 138 | if err != nil { 139 | return nil, fmt.Errorf("res.Open error: %w", err) 140 | } 141 | ok1, err := src1.Next() 142 | if err != nil { 143 | return nil, fmt.Errorf("src1.Next error: %w", err) 144 | } 145 | ok2, err := src2.Next() 146 | if err != nil { 147 | return nil, fmt.Errorf("src2.Next error: %w", err) 148 | } 149 | for ok1 && ok2 { 150 | cmpRes, err := s.order.Compare(src1, src2) 151 | if err != nil { 152 | return nil, fmt.Errorf("NewOrder.Compare error: %w", err) 
153 | } 154 | if cmpRes < 0 { 155 | ok1, err = s.copy(src1, dest) 156 | if err != nil { 157 | return nil, fmt.Errorf("copy error: %w", err) 158 | } 159 | } else { 160 | ok2, err = s.copy(src2, dest) 161 | if err != nil { 162 | return nil, fmt.Errorf("copy error: %w", err) 163 | } 164 | } 165 | } 166 | if ok1 { 167 | for ok1 { 168 | ok1, err = s.copy(src1, dest) 169 | if err != nil { 170 | return nil, fmt.Errorf("copy error: %w", err) 171 | } 172 | } 173 | } else { 174 | for ok2 { 175 | ok2, err = s.copy(src2, dest) 176 | if err != nil { 177 | return nil, fmt.Errorf("copy error: %w", err) 178 | } 179 | } 180 | } 181 | if err := src1.Close(); err != nil { 182 | return nil, fmt.Errorf("src1.Close error: %w", err) 183 | } 184 | if err := src2.Close(); err != nil { 185 | return nil, fmt.Errorf("src2.Close error: %w", err) 186 | } 187 | if err := dest.Close(); err != nil { 188 | return nil, fmt.Errorf("dest.Close error: %w", err) 189 | } 190 | newRuns = append(newRuns, *res) 191 | } 192 | if len(runs) == 1 { 193 | newRuns = append(newRuns, runs[0]) 194 | } 195 | runs = newRuns 196 | } 197 | newScan, err := NewSortScan(runs, s.order) 198 | if err != nil { 199 | return nil, fmt.Errorf("NewSortScan error: %w", err) 200 | } 201 | return newScan, nil 202 | } 203 | 204 | func (s SortPlan) BlockAccessed() int { 205 | return NewMaterializePlan(s.tx, s.p).BlockAccessed() 206 | } 207 | 208 | func (s SortPlan) RecordsOutput() int { 209 | return s.p.RecordsOutput() 210 | } 211 | 212 | func (s SortPlan) DistinctValues(fieldName schema.FieldName) int { 213 | return s.p.DistinctValues(fieldName) 214 | } 215 | 216 | func (s SortPlan) Schema() *schema.Schema { 217 | return &s.sche 218 | } 219 | 220 | func (s SortPlan) copy(src query.Scan, dest query.UpdateScan) (bool, error) { 221 | if err := dest.Insert(); err != nil { 222 | return false, fmt.Errorf("dest.Insert error: %w", err) 223 | } 224 | for _, fldName := range s.sche.FieldNames() { 225 | val, err := src.Val(fldName) 226 | if err 
!= nil { 227 | return false, fmt.Errorf("src.Val error: %w", err) 228 | } 229 | if err := dest.SetVal(fldName, val); err != nil { 230 | return false, fmt.Errorf("dest.SetVal error: %w", err) 231 | } 232 | } 233 | ok, err := src.Next() 234 | if err != nil { 235 | return false, fmt.Errorf("src.Next error: %w", err) 236 | } 237 | return ok, nil 238 | } 239 | 240 | type SortScan struct { 241 | s1, s2, currentScan query.UpdateScan 242 | order query.Order 243 | hasMore1, hasMore2 bool 244 | savedPosition1, savedPosition2 *schema.RID 245 | } 246 | 247 | var _ query.Scan = (*SortScan)(nil) 248 | 249 | func NewSortScan(runs []TempTable, order query.Order) (*SortScan, error) { 250 | if len(runs) == 0 || len(runs) > 2 { 251 | return nil, fmt.Errorf("runs length error") 252 | } 253 | s1, err := runs[0].Open() 254 | if err != nil { 255 | return nil, fmt.Errorf("runs[0].Open error: %w", err) 256 | } 257 | var s2 query.UpdateScan 258 | if len(runs) == 2 { 259 | s2, err = runs[1].Open() 260 | if err != nil { 261 | return nil, fmt.Errorf("runs[1].Open error: %w", err) 262 | } 263 | } 264 | hasMore1, err := s1.Next() 265 | if err != nil { 266 | return nil, fmt.Errorf("s1.Next error: %w", err) 267 | } 268 | var hasMore2 bool 269 | if s2 != nil { 270 | ok, err := s2.Next() 271 | if err != nil { 272 | return nil, fmt.Errorf("s2.Next error: %w", err) 273 | } 274 | hasMore2 = ok 275 | } 276 | return &SortScan{ 277 | s1: s1, 278 | s2: s2, 279 | order: order, 280 | hasMore1: hasMore1, 281 | hasMore2: hasMore2, 282 | }, nil 283 | } 284 | 285 | func (s *SortScan) Val(fieldName schema.FieldName) (schema.Constant, error) { 286 | if s.currentScan == nil { 287 | return nil, fmt.Errorf("currentScan is nil") 288 | } 289 | return s.currentScan.Val(fieldName) 290 | } 291 | 292 | func (s *SortScan) BeforeFirst() error { 293 | s.currentScan = nil 294 | if err := s.s1.BeforeFirst(); err != nil { 295 | return fmt.Errorf("s1.BeforeFirst error: %w", err) 296 | } 297 | hasMore1, err := s.s1.Next() 298 | if 
err != nil { 299 | return fmt.Errorf("s1.Next error: %w", err) 300 | } 301 | s.hasMore1 = hasMore1 302 | if s.s2 != nil { 303 | if err := s.s2.BeforeFirst(); err != nil { 304 | return fmt.Errorf("s2.BeforeFirst error: %w", err) 305 | } 306 | hasMore2, err := s.s2.Next() 307 | if err != nil { 308 | return fmt.Errorf("s2.Next error: %w", err) 309 | } 310 | s.hasMore2 = hasMore2 311 | } 312 | return nil 313 | } 314 | 315 | func (s *SortScan) Next() (bool, error) { 316 | if s.currentScan != nil { 317 | if s.currentScan == s.s1 { 318 | hasMore1, err := s.s1.Next() 319 | if err != nil { 320 | return false, fmt.Errorf("s1.Next error: %w", err) 321 | } 322 | s.hasMore1 = hasMore1 323 | } else if s.currentScan == s.s2 { 324 | hasMore2, err := s.s2.Next() 325 | if err != nil { 326 | return false, fmt.Errorf("s2.Next error: %w", err) 327 | } 328 | s.hasMore2 = hasMore2 329 | } 330 | } 331 | if !s.hasMore1 && !s.hasMore2 { 332 | return false, nil 333 | } else if s.hasMore1 && s.hasMore2 { 334 | cmpRes, err := s.order.Compare(s.s1, s.s2) 335 | if err != nil { 336 | return false, fmt.Errorf("comparator.Compare error: %w", err) 337 | } 338 | if cmpRes < 0 { 339 | s.currentScan = s.s1 340 | } else { 341 | s.currentScan = s.s2 342 | } 343 | } else if s.hasMore1 { 344 | s.currentScan = s.s1 345 | } else { 346 | s.currentScan = s.s2 347 | } 348 | return true, nil 349 | } 350 | 351 | func (s *SortScan) Int32(fieldName schema.FieldName) (int32, error) { 352 | if s.currentScan == nil { 353 | return 0, fmt.Errorf("currentScan is nil") 354 | } 355 | v, err := s.currentScan.Int32(fieldName) 356 | if err != nil { 357 | return 0, fmt.Errorf("currentScan.Int32 error: %w", err) 358 | } 359 | return v, nil 360 | } 361 | 362 | func (s *SortScan) Str(fieldName schema.FieldName) (string, error) { 363 | if s.currentScan == nil { 364 | return "", fmt.Errorf("currentScan is nil") 365 | } 366 | v, err := s.currentScan.Str(fieldName) 367 | if err != nil { 368 | return "", fmt.Errorf("currentScan.Str 
error: %w", err) 369 | } 370 | return v, nil 371 | } 372 | 373 | func (s *SortScan) HasField(fieldName schema.FieldName) bool { 374 | if s.currentScan == nil { 375 | return false 376 | } 377 | return s.currentScan.HasField(fieldName) 378 | } 379 | 380 | func (s *SortScan) Close() error { 381 | if s.s1 != nil { 382 | if err := s.s1.Close(); err != nil { 383 | return fmt.Errorf("s1.Close error: %w", err) 384 | } 385 | } 386 | if s.s2 != nil { 387 | if err := s.s2.Close(); err != nil { 388 | return fmt.Errorf("s2.Close error: %w", err) 389 | } 390 | } 391 | return nil 392 | } 393 | 394 | func (s *SortScan) SavePosition() { 395 | rid1 := s.s1.RID() 396 | s.savedPosition1 = &rid1 397 | if s.s2 != nil { 398 | rid2 := s.s2.RID() 399 | s.savedPosition2 = &rid2 400 | } 401 | } 402 | 403 | func (s *SortScan) RestorePosition() error { 404 | if s.savedPosition1 != nil { 405 | if err := s.s1.MoveToRID(*s.savedPosition1); err != nil { 406 | return fmt.Errorf("s1.MoveToRID error: %w", err) 407 | } 408 | } 409 | if s.savedPosition2 != nil { 410 | if err := s.s2.MoveToRID(*s.savedPosition2); err != nil { 411 | return fmt.Errorf("s2.MoveToRID error: %w", err) 412 | } 413 | } 414 | return nil 415 | } 416 | -------------------------------------------------------------------------------- /internal/plan/sort_test.go: -------------------------------------------------------------------------------- 1 | package plan_test 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/abekoh/simple-db/internal/plan" 9 | "github.com/abekoh/simple-db/internal/query" 10 | "github.com/abekoh/simple-db/internal/record" 11 | "github.com/abekoh/simple-db/internal/record/schema" 12 | "github.com/abekoh/simple-db/internal/simpledb" 13 | "github.com/abekoh/simple-db/internal/transaction" 14 | ) 15 | 16 | func TestSortPlan(t *testing.T) { 17 | transaction.CleanupLockTable(t) 18 | ctx := context.Background() 19 | db, err := simpledb.New(ctx, t.TempDir()) 20 | if err != nil { 21 | 
t.Fatal(err) 22 | } 23 | tx, err := db.NewTx(ctx) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | 28 | sche := schema.NewSchema() 29 | sche.AddInt32Field("A") 30 | sche.AddStrField("B", 9) 31 | if err := db.MetadataMgr().CreateTable("mytable", sche, tx); err != nil { 32 | t.Fatal(err) 33 | } 34 | 35 | layout := record.NewLayoutSchema(sche) 36 | updateScan, err := record.NewTableScan(tx, "mytable", layout) 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | 41 | if err := updateScan.BeforeFirst(); err != nil { 42 | t.Fatal(err) 43 | } 44 | for _, v := range []string{"rec2", "rec5", "rec1", "rec4", "rec3"} { 45 | if err := updateScan.Insert(); err != nil { 46 | t.Fatal(err) 47 | } 48 | rec := schema.ConstantStr(v) 49 | if err := updateScan.SetVal("B", rec); err != nil { 50 | t.Fatal(err) 51 | } 52 | } 53 | if err := updateScan.Close(); err != nil { 54 | t.Fatal(err) 55 | } 56 | 57 | tablePlan, err := plan.NewTablePlan("mytable", tx, db.MetadataMgr()) 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | sortPlan := plan.NewSortPlan(tx, tablePlan, query.Order{query.OrderElement{Field: "B", OrderType: query.Asc}}) 62 | projectPlan := plan.NewProjectPlan(sortPlan, []schema.FieldName{"B"}) 63 | queryScan, err := projectPlan.Open() 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | if err := queryScan.BeforeFirst(); err != nil { 68 | t.Fatal(err) 69 | } 70 | res := make([]string, 0, 5) 71 | for { 72 | ok, err := queryScan.Next() 73 | if err != nil { 74 | t.Fatal(err) 75 | } 76 | if !ok { 77 | break 78 | } 79 | val, err := queryScan.Str("B") 80 | if err != nil { 81 | t.Fatal(err) 82 | } 83 | res = append(res, val) 84 | } 85 | if err := queryScan.Close(); err != nil { 86 | t.Fatal(err) 87 | } 88 | if !reflect.DeepEqual(res, []string{"rec1", "rec2", "rec3", "rec4", "rec5"}) { 89 | t.Fatalf("unexpected result: %v", res) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /internal/postgres/server.go: 
-------------------------------------------------------------------------------- 1 | package postgres 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "log/slog" 9 | "net" 10 | "os" 11 | 12 | "github.com/abekoh/simple-db/internal/simpledb" 13 | ) 14 | 15 | type Config struct { 16 | Dir string 17 | Address string 18 | } 19 | 20 | func RunServer(ctx context.Context, cfg Config) error { 21 | address := cfg.Address 22 | if address == "" { 23 | address = "127.0.0.1:45432" 24 | } 25 | listen, err := net.Listen("tcp", address) 26 | if err != nil { 27 | return fmt.Errorf("could not listen on %s: %w", address, err) 28 | } 29 | slog.InfoContext(ctx, "Listening", "addr", listen.Addr()) 30 | go func() { 31 | <-ctx.Done() 32 | listen.Close() 33 | }() 34 | 35 | dir := cfg.Dir 36 | if dir == "" { 37 | dir, err = os.MkdirTemp(os.TempDir(), "simpledb") 38 | if err != nil { 39 | return fmt.Errorf("could not create temp dir: %w", err) 40 | } 41 | } 42 | 43 | db, err := simpledb.New(ctx, dir) 44 | if err != nil { 45 | return fmt.Errorf("could not create SimpleDB: %w", err) 46 | } 47 | 48 | for { 49 | conn, err := listen.Accept() 50 | if err != nil { 51 | return fmt.Errorf("could not accept connection: %w", err) 52 | } 53 | slog.InfoContext(ctx, "Accepted connection", "remote_addr", conn.RemoteAddr()) 54 | 55 | b := NewBackend(db, conn) 56 | 57 | go func() { 58 | err := b.Run() 59 | if err != nil { 60 | if !errors.Is(err, io.EOF) { 61 | panic(err) 62 | } 63 | } 64 | slog.InfoContext(ctx, "Closed connection", "remote_addr", conn.RemoteAddr()) 65 | }() 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /internal/postgres/server_bench_test.go: -------------------------------------------------------------------------------- 1 | package postgres 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "testing" 7 | 8 | "github.com/abekoh/simple-db/internal/testdata" 9 | "github.com/abekoh/simple-db/internal/transaction" 10 | 
"github.com/jackc/pgx/v5" 11 | ) 12 | 13 | func BenchmarkPostgres_SelectOneRow(b *testing.B) { 14 | run := func(b *testing.B, srcDirname string) { 15 | b.Helper() 16 | 17 | slog.SetLogLoggerLevel(slog.LevelError) 18 | 19 | transaction.CleanupLockTable(b) 20 | ctx, cancel := context.WithCancel(context.Background()) 21 | b.Cleanup(cancel) 22 | dir := b.TempDir() 23 | if err := testdata.CopySnapshotData(srcDirname, dir); err != nil { 24 | b.Fatal(err) 25 | } 26 | cfg := Config{ 27 | Dir: dir, 28 | Address: "127.0.0.1:54329", 29 | } 30 | go func() { 31 | _ = RunServer(ctx, cfg) 32 | }() 33 | 34 | pgCfg, err := pgx.ParseConfig("postgres://postgres@127.0.0.1:54329/postgres") 35 | if err != nil { 36 | b.Fatal(err) 37 | } 38 | conn, err := pgx.ConnectConfig(ctx, pgCfg) 39 | if err != nil { 40 | b.Fatal(err) 41 | } 42 | 43 | studentIDs := []int{ 44 | 200001, 45 | 200376, 46 | 204199, 47 | 208321, 48 | 210000, 49 | } 50 | 51 | type Row struct { 52 | StudentID int 53 | Name string 54 | } 55 | 56 | b.ResetTimer() 57 | for range b.N { 58 | for _, studentID := range studentIDs { 59 | var r Row 60 | if err := conn.QueryRow(ctx, "SELECT student_id, name FROM students WHERE student_id = $1", studentID).Scan(&r.StudentID, &r.Name); err != nil { 61 | b.Fatal(err) 62 | } 63 | if r.StudentID != studentID { 64 | b.Errorf("unexpected student_id: %d", r.StudentID) 65 | } 66 | } 67 | } 68 | } 69 | b.Run("no index", func(b *testing.B) { 70 | run(b, "tables_data") 71 | }) 72 | b.Run("use index", func(b *testing.B) { 73 | run(b, "tables_indexes_data") 74 | }) 75 | } 76 | -------------------------------------------------------------------------------- /internal/postgres/server_test.go: -------------------------------------------------------------------------------- 1 | package postgres 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/abekoh/simple-db/internal/transaction" 9 | "github.com/jackc/pgx/v5" 10 | ) 11 | 12 | func TestPostgres(t *testing.T) { 13 | 
transaction.CleanupLockTable(t) 14 | ctx, cancel := context.WithCancel(context.Background()) 15 | t.Cleanup(cancel) 16 | cfg := Config{ 17 | Dir: t.TempDir(), 18 | Address: "127.0.0.1:54329", 19 | } 20 | go func() { 21 | _ = RunServer(ctx, cfg) 22 | }() 23 | 24 | pgCfg, err := pgx.ParseConfig("postgres://postgres@127.0.0.1:54329/postgres") 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | conn, err := pgx.ConnectConfig(ctx, pgCfg) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | 33 | tag, err := conn.Exec(ctx, "CREATE TABLE mytable (id INT, name VARCHAR(10))") 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | if tag.String() != "CREATE TABLE" { 38 | t.Errorf("unexpected tag: %s", tag) 39 | } 40 | 41 | for _, args := range [][]any{ 42 | {1, "foo"}, 43 | {2, "bar"}, 44 | {3, "baz"}, 45 | } { 46 | tag, err = conn.Exec(ctx, "INSERT INTO mytable (id, name) VALUES ($1, $2)", args...) 47 | if err != nil { 48 | t.Fatal(err) 49 | } 50 | if tag.String() != "INSERT 0 1" { 51 | t.Errorf("unexpected tag: %s", tag) 52 | } 53 | } 54 | 55 | tag, err = conn.Exec(ctx, "UPDATE mytable SET name = 'HOGE' WHERE id = $1", 3) 56 | if err != nil { 57 | t.Fatal(err) 58 | } 59 | if tag.String() != "UPDATE 1" { 60 | t.Errorf("unexpected tag: %s", tag) 61 | } 62 | 63 | tag, err = conn.Exec(ctx, "DELETE FROM mytable WHERE id = $1", 2) 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | if tag.String() != "DELETE 1" { 68 | t.Errorf("unexpected tag: %s", tag) 69 | } 70 | 71 | explainRows, err := conn.Query(ctx, "EXPLAIN SELECT id, name FROM mytable") 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | defer explainRows.Close() 76 | type ExplainRow struct { 77 | QueryPlan string 78 | } 79 | resExplainRows := make([]ExplainRow, 0) 80 | for explainRows.Next() { 81 | var row ExplainRow 82 | if err := explainRows.Scan(&row.QueryPlan); err != nil { 83 | t.Fatal(err) 84 | } 85 | resExplainRows = append(resExplainRows, row) 86 | } 87 | if len(resExplainRows) != 1 { 88 | t.Fatalf("unexpected rows: %v", 
resExplainRows) 89 | } 90 | if resExplainRows[0].QueryPlan != `Project fields=id,name (ba=0,ro=0) 91 | Table table=mytable (ba=0,ro=0)` { 92 | t.Errorf("unexpected rows: %v", resExplainRows) 93 | } 94 | 95 | rows, err := conn.Query(ctx, "SELECT id, name FROM mytable") 96 | if err != nil { 97 | t.Fatal(err) 98 | } 99 | defer rows.Close() 100 | 101 | type Row struct { 102 | ID int32 103 | Name string 104 | } 105 | resRows := make([]Row, 0) 106 | for rows.Next() { 107 | var row Row 108 | if err := rows.Scan(&row.ID, &row.Name); err != nil { 109 | t.Fatal(err) 110 | } 111 | resRows = append(resRows, row) 112 | } 113 | if len(resRows) != 2 { 114 | t.Errorf("unexpected rows: %v", resRows) 115 | } 116 | if !reflect.DeepEqual(resRows, []Row{ 117 | {ID: 1, Name: "foo"}, 118 | {ID: 3, Name: "HOGE"}, 119 | }) { 120 | t.Errorf("unexpected rows: %v", resRows) 121 | } 122 | 123 | var queryRow Row 124 | if err := conn.QueryRow(ctx, "SELECT id, name FROM mytable WHERE id = $1", 3).Scan(&queryRow.ID, &queryRow.Name); err != nil { 125 | t.Fatal(err) 126 | } 127 | if !reflect.DeepEqual(queryRow, Row{ID: 3, Name: "HOGE"}) { 128 | t.Errorf("unexpected row: %v", queryRow) 129 | } 130 | } 131 | 132 | func TestPostgres_Transaction(t *testing.T) { 133 | transaction.CleanupLockTable(t) 134 | ctx, cancel := context.WithCancel(context.Background()) 135 | t.Cleanup(cancel) 136 | cfg := Config{ 137 | Dir: t.TempDir(), 138 | Address: "127.0.0.1:54329", 139 | } 140 | go func() { 141 | _ = RunServer(ctx, cfg) 142 | }() 143 | 144 | pgCfg, err := pgx.ParseConfig("postgres://postgres@127.0.0.1:54329/postgres") 145 | if err != nil { 146 | t.Fatal(err) 147 | } 148 | conn, err := pgx.ConnectConfig(ctx, pgCfg) 149 | if err != nil { 150 | t.Fatal(err) 151 | } 152 | 153 | tag, err := conn.Exec(ctx, "CREATE TABLE mytable (id INT, name VARCHAR(10))") 154 | if err != nil { 155 | t.Fatal(err) 156 | } 157 | if tag.String() != "CREATE TABLE" { 158 | t.Errorf("unexpected tag: %s", tag) 159 | } 160 | 161 | type 
Row struct { 162 | ID int32 163 | Name string 164 | } 165 | 166 | assertID1 := func(expected Row) { 167 | var queryRow Row 168 | if err := conn.QueryRow(ctx, "SELECT id, name FROM mytable WHERE id = $1", 1).Scan(&queryRow.ID, &queryRow.Name); err != nil { 169 | t.Fatal(err) 170 | } 171 | if !reflect.DeepEqual(queryRow, expected) { 172 | t.Errorf("unexpected row: %v", queryRow) 173 | } 174 | } 175 | 176 | tag, err = conn.Exec(ctx, "INSERT INTO mytable (id, name) VALUES ($1, $2)", 1, "foo") 177 | if err != nil { 178 | t.Fatal(err) 179 | } 180 | if tag.String() != "INSERT 0 1" { 181 | t.Errorf("unexpected tag: %s", tag) 182 | } 183 | assertID1(Row{ID: 1, Name: "foo"}) 184 | 185 | conn1, err := pgx.ConnectConfig(ctx, pgCfg) 186 | if err != nil { 187 | t.Fatal(err) 188 | } 189 | tx1, err := conn1.Begin(ctx) 190 | if err != nil { 191 | t.Fatal(err) 192 | } 193 | tag, err = tx1.Exec(ctx, "UPDATE mytable SET name = 'HOGE' WHERE id = $1", 1) 194 | if err != nil { 195 | t.Fatal(err) 196 | } 197 | if tag.String() != "UPDATE 1" { 198 | t.Errorf("unexpected tag: %s", tag) 199 | } 200 | err = tx1.Rollback(ctx) 201 | if err != nil { 202 | t.Fatal(err) 203 | } 204 | assertID1(Row{ID: 1, Name: "foo"}) 205 | 206 | conn2, err := pgx.ConnectConfig(ctx, pgCfg) 207 | if err != nil { 208 | t.Fatal(err) 209 | } 210 | tx2, err := conn2.Begin(ctx) 211 | if err != nil { 212 | t.Fatal(err) 213 | } 214 | tag, err = tx2.Exec(ctx, "UPDATE mytable SET name = 'HOGE' WHERE id = $1", 1) 215 | if err != nil { 216 | t.Fatal(err) 217 | } 218 | if tag.String() != "UPDATE 1" { 219 | t.Errorf("unexpected tag: %s", tag) 220 | } 221 | err = tx2.Commit(ctx) 222 | if err != nil { 223 | t.Fatal(err) 224 | } 225 | assertID1(Row{ID: 1, Name: "HOGE"}) 226 | } 227 | -------------------------------------------------------------------------------- /internal/query/query_test.go: -------------------------------------------------------------------------------- 1 | package query_test 2 | 3 | import ( 4 | "context" 5 | 
"fmt" 6 | "reflect" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/abekoh/simple-db/internal/query" 11 | "github.com/abekoh/simple-db/internal/record" 12 | "github.com/abekoh/simple-db/internal/record/schema" 13 | "github.com/abekoh/simple-db/internal/simpledb" 14 | ) 15 | 16 | func TestProductScan(t *testing.T) { 17 | ctx := context.Background() 18 | db, err := simpledb.New(ctx, t.TempDir()) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | tx, err := db.NewTx(ctx) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | 27 | sche1 := schema.NewSchema() 28 | sche1.AddInt32Field("A") 29 | sche1.AddStrField("B", 9) 30 | layout1 := record.NewLayoutSchema(sche1) 31 | ts1, err := record.NewTableScan(tx, "T1", layout1) 32 | if err != nil { 33 | t.Fatal(err) 34 | } 35 | 36 | sche2 := schema.NewSchema() 37 | sche2.AddInt32Field("C") 38 | sche2.AddStrField("D", 9) 39 | layout2 := record.NewLayoutSchema(sche2) 40 | ts2, err := record.NewTableScan(tx, "T2", layout2) 41 | if err != nil { 42 | t.Fatal(err) 43 | } 44 | 45 | if err := ts1.BeforeFirst(); err != nil { 46 | t.Fatal(err) 47 | } 48 | n := 5 49 | for i := 0; i < n; i++ { 50 | if err := ts1.Insert(); err != nil { 51 | t.Fatal(err) 52 | } 53 | if err := ts1.SetInt32("A", int32(i)); err != nil { 54 | t.Fatal(err) 55 | } 56 | if err := ts1.SetStr("B", fmt.Sprintf("aaa%d", i)); err != nil { 57 | t.Fatal(err) 58 | } 59 | } 60 | if err := ts1.Close(); err != nil { 61 | t.Fatal(err) 62 | } 63 | 64 | if err := ts2.BeforeFirst(); err != nil { 65 | t.Fatal(err) 66 | } 67 | for i := 0; i < n; i++ { 68 | if err := ts2.Insert(); err != nil { 69 | t.Fatal(err) 70 | } 71 | if err := ts2.SetInt32("C", int32(i)); err != nil { 72 | t.Fatal(err) 73 | } 74 | if err := ts2.SetStr("D", fmt.Sprintf("bbb%d", i)); err != nil { 75 | t.Fatal(err) 76 | } 77 | } 78 | if err := ts2.Close(); err != nil { 79 | t.Fatal(err) 80 | } 81 | 82 | ts1p, err := record.NewTableScan(tx, "T1", layout1) 83 | if err != nil { 84 | t.Fatal(err) 85 | } 86 | ts2p, err := 
record.NewTableScan(tx, "T2", layout2) 87 | if err != nil { 88 | t.Fatal(err) 89 | } 90 | ps, err := query.NewProductScan(ts1p, ts2p) 91 | if err != nil { 92 | t.Fatal(err) 93 | } 94 | got := make([]string, 0, n*n) 95 | for { 96 | ok, err := ps.Next() 97 | if err != nil { 98 | t.Fatal(err) 99 | } 100 | if !ok { 101 | break 102 | } 103 | a, err := ps.Int32("A") 104 | if err != nil { 105 | t.Fatal(err) 106 | } 107 | b, err := ps.Str("B") 108 | if err != nil { 109 | t.Fatal(err) 110 | } 111 | c, err := ps.Int32("C") 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | d, err := ps.Str("D") 116 | if err != nil { 117 | t.Fatal(err) 118 | } 119 | got = append(got, fmt.Sprintf("{%d, %s, %d, %s}", a, b, c, d)) 120 | } 121 | if len(got) != n*n { 122 | t.Errorf("got %d, want %d", len(got), n*n) 123 | } 124 | expected := `{0, aaa0, 0, bbb0} 125 | {0, aaa0, 1, bbb1} 126 | {0, aaa0, 2, bbb2} 127 | {0, aaa0, 3, bbb3} 128 | {0, aaa0, 4, bbb4} 129 | {1, aaa1, 0, bbb0} 130 | {1, aaa1, 1, bbb1} 131 | {1, aaa1, 2, bbb2} 132 | {1, aaa1, 3, bbb3} 133 | {1, aaa1, 4, bbb4} 134 | {2, aaa2, 0, bbb0} 135 | {2, aaa2, 1, bbb1} 136 | {2, aaa2, 2, bbb2} 137 | {2, aaa2, 3, bbb3} 138 | {2, aaa2, 4, bbb4} 139 | {3, aaa3, 0, bbb0} 140 | {3, aaa3, 1, bbb1} 141 | {3, aaa3, 2, bbb2} 142 | {3, aaa3, 3, bbb3} 143 | {3, aaa3, 4, bbb4} 144 | {4, aaa4, 0, bbb0} 145 | {4, aaa4, 1, bbb1} 146 | {4, aaa4, 2, bbb2} 147 | {4, aaa4, 3, bbb3} 148 | {4, aaa4, 4, bbb4}` 149 | if strings.Join(got, "\n") != expected { 150 | t.Errorf("got %s, want %s", strings.Join(got, "\n"), expected) 151 | } 152 | if err := ps.Close(); err != nil { 153 | t.Fatal(err) 154 | } 155 | if err := tx.Commit(); err != nil { 156 | t.Fatal(err) 157 | } 158 | } 159 | 160 | func TestScan(t *testing.T) { 161 | t.Run("TableScan -> SelectScan -> ProjectScan", func(t *testing.T) { 162 | ctx := context.Background() 163 | db, err := simpledb.New(ctx, t.TempDir()) 164 | if err != nil { 165 | t.Fatal(err) 166 | } 167 | tx, err := db.NewTx(ctx) 168 | 
if err != nil { 169 | t.Fatal(err) 170 | } 171 | 172 | sche := schema.NewSchema() 173 | sche.AddInt32Field("A") 174 | sche.AddStrField("B", 9) 175 | layout := record.NewLayoutSchema(sche) 176 | scan1, err := record.NewTableScan(tx, "T", layout) 177 | if err != nil { 178 | t.Fatal(err) 179 | } 180 | 181 | if err := scan1.BeforeFirst(); err != nil { 182 | t.Fatal(err) 183 | } 184 | for i := 0; i < 200; i++ { 185 | if err := scan1.Insert(); err != nil { 186 | t.Fatal(err) 187 | } 188 | if err := scan1.SetInt32("A", int32(i/10)); err != nil { 189 | t.Fatal(err) 190 | } 191 | if err := scan1.SetStr("B", fmt.Sprintf("rec%d", i)); err != nil { 192 | t.Fatal(err) 193 | } 194 | } 195 | if err := scan1.Close(); err != nil { 196 | t.Fatal(err) 197 | } 198 | 199 | scan2, err := record.NewTableScan(tx, "T", layout) 200 | if err != nil { 201 | t.Fatal(err) 202 | } 203 | term := query.NewTerm(schema.FieldName("A"), schema.ConstantInt32(10)) 204 | pred := query.NewPredicate(term) 205 | if pred.String() != "A=10" { 206 | t.Fatalf("unexpected string: %s", pred.String()) 207 | } 208 | scan3 := query.NewSelectScan(scan2, pred) 209 | scan4 := query.NewProjectScan(scan3, "B") 210 | got := make([]string, 0, 10) 211 | for { 212 | ok, err := scan4.Next() 213 | if err != nil { 214 | t.Fatal(err) 215 | } 216 | if !ok { 217 | break 218 | } 219 | b, err := scan4.Str("B") 220 | if err != nil { 221 | t.Fatal(err) 222 | } 223 | got = append(got, b) 224 | } 225 | if len(got) != 10 { 226 | t.Errorf("got %d, want %d", len(got), 10) 227 | } 228 | expected := []string{"rec100", "rec101", "rec102", "rec103", "rec104", "rec105", "rec106", "rec107", "rec108", "rec109"} 229 | if !reflect.DeepEqual(got, expected) { 230 | t.Errorf("got %v, want %v", got, expected) 231 | } 232 | if err := scan4.Close(); err != nil { 233 | t.Fatal(err) 234 | } 235 | if err := tx.Commit(); err != nil { 236 | t.Fatal(err) 237 | } 238 | }) 239 | t.Run("TableScan*2 -> ProductScan -> SelectScan -> ProjectScan", func(t *testing.T) 
{ 240 | ctx := context.Background() 241 | db, err := simpledb.New(ctx, t.TempDir()) 242 | if err != nil { 243 | t.Fatal(err) 244 | } 245 | tx, err := db.NewTx(ctx) 246 | if err != nil { 247 | t.Fatal(err) 248 | } 249 | 250 | sche1 := schema.NewSchema() 251 | sche1.AddInt32Field("A") 252 | sche1.AddStrField("B", 9) 253 | layout1 := record.NewLayoutSchema(sche1) 254 | us1, err := record.NewTableScan(tx, "T1", layout1) 255 | if err != nil { 256 | t.Fatal(err) 257 | } 258 | if err := us1.BeforeFirst(); err != nil { 259 | t.Fatal(err) 260 | } 261 | n := 5 262 | for i := 0; i < n; i++ { 263 | if err := us1.Insert(); err != nil { 264 | t.Fatal(err) 265 | } 266 | if err := us1.SetInt32("A", int32(i)); err != nil { 267 | t.Fatal(err) 268 | } 269 | if err := us1.SetStr("B", fmt.Sprintf("bbb%d", i)); err != nil { 270 | t.Fatal(err) 271 | } 272 | } 273 | if err := us1.Close(); err != nil { 274 | t.Fatal(err) 275 | } 276 | 277 | sche2 := schema.NewSchema() 278 | sche2.AddInt32Field("C") 279 | sche2.AddStrField("D", 9) 280 | layout2 := record.NewLayoutSchema(sche2) 281 | us2, err := record.NewTableScan(tx, "T2", layout2) 282 | if err != nil { 283 | t.Fatal(err) 284 | } 285 | if err := us2.BeforeFirst(); err != nil { 286 | t.Fatal(err) 287 | } 288 | for i := 0; i < n; i++ { 289 | if err := us2.Insert(); err != nil { 290 | t.Fatal(err) 291 | } 292 | if err := us2.SetInt32("C", int32(i)); err != nil { 293 | t.Fatal(err) 294 | } 295 | if err := us2.SetStr("D", fmt.Sprintf("ddd%d", i)); err != nil { 296 | t.Fatal(err) 297 | } 298 | } 299 | if err := us2.Close(); err != nil { 300 | t.Fatal(err) 301 | } 302 | 303 | s1p, err := record.NewTableScan(tx, "T1", layout1) 304 | if err != nil { 305 | t.Fatal(err) 306 | } 307 | s2p, err := record.NewTableScan(tx, "T2", layout2) 308 | if err != nil { 309 | t.Fatal(err) 310 | } 311 | prodS, err := query.NewProductScan(s1p, s2p) 312 | if err != nil { 313 | t.Fatal(err) 314 | } 315 | term := query.NewTerm(schema.FieldName("A"), 
schema.FieldName("C")) 316 | pred := query.NewPredicate(term) 317 | ss := query.NewSelectScan(prodS, pred) 318 | prjS := query.NewProjectScan(ss, "B", "D") 319 | got := make([]string, 0, n*n) 320 | for { 321 | ok, err := prjS.Next() 322 | if err != nil { 323 | t.Fatal(err) 324 | } 325 | if !ok { 326 | break 327 | } 328 | b, err := prjS.Str("B") 329 | if err != nil { 330 | t.Fatal(err) 331 | } 332 | d, err := prjS.Str("D") 333 | if err != nil { 334 | t.Fatal(err) 335 | } 336 | got = append(got, fmt.Sprintf("%s, %s", b, d)) 337 | } 338 | if len(got) != n { 339 | t.Errorf("got %d, want %d", len(got), n) 340 | } 341 | expected := `bbb0, ddd0 342 | bbb1, ddd1 343 | bbb2, ddd2 344 | bbb3, ddd3 345 | bbb4, ddd4` 346 | if strings.Join(got, "\n") != expected { 347 | t.Errorf("got %s, want %s", strings.Join(got, "\n"), expected) 348 | } 349 | if err := prjS.Close(); err != nil { 350 | t.Fatal(err) 351 | } 352 | if err := tx.Commit(); err != nil { 353 | t.Fatal(err) 354 | } 355 | }) 356 | } 357 | -------------------------------------------------------------------------------- /internal/record/record_test.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/abekoh/simple-db/internal/buffer" 10 | "github.com/abekoh/simple-db/internal/file" 11 | "github.com/abekoh/simple-db/internal/log" 12 | "github.com/abekoh/simple-db/internal/record/schema" 13 | "github.com/abekoh/simple-db/internal/transaction" 14 | ) 15 | 16 | func TestLayout(t *testing.T) { 17 | t.Parallel() 18 | 19 | s := schema.NewSchema() 20 | s.AddInt32Field("A") 21 | s.AddStrField("B", 9) 22 | l := NewLayoutSchema(s) 23 | if offset, ok := l.Offset("A"); !ok || offset != 4 { 24 | t.Errorf("expected 0, got %d", offset) 25 | } 26 | if offset, ok := l.Offset("B"); !ok || offset != 8 { 27 | t.Errorf("expected 8, got %d", offset) 28 | } 29 | } 30 | 31 | func TestRecordPage(t 
*testing.T) { 32 | t.Parallel() 33 | 34 | fm, err := file.NewManager(t.TempDir(), 128) 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | lm, err := log.NewManager(fm, "logfile") 39 | if err != nil { 40 | t.Fatal(err) 41 | } 42 | ctx := context.Background() 43 | bm := buffer.NewManager(ctx, fm, lm, 8) 44 | 45 | tx, err := transaction.NewTransaction(ctx, bm, fm, lm) 46 | if err != nil { 47 | t.Fatal(err) 48 | } 49 | 50 | schema := schema.NewSchema() 51 | schema.AddInt32Field("A") 52 | schema.AddStrField("B", 9) 53 | 54 | layout := NewLayoutSchema(schema) 55 | blockID, err := tx.Append("testfile") 56 | if err != nil { 57 | t.Fatal(err) 58 | } 59 | if _, err := tx.Pin(blockID); err != nil { 60 | t.Fatal(err) 61 | } 62 | rp, err := NewPage(tx, blockID, layout) 63 | if err != nil { 64 | t.Fatal(err) 65 | } 66 | if err := rp.Format(); err != nil { 67 | t.Fatal(err) 68 | } 69 | 70 | slot, ok, err := rp.InsertAfter(-1) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | for i := 0; ok; i++ { 75 | if err := rp.SetInt32(slot, "A", int32(i)); err != nil { 76 | t.Fatal(err) 77 | } 78 | s := fmt.Sprintf("rec%d", i) 79 | if err := rp.SetStr(slot, "B", s); err != nil { 80 | t.Fatal(err) 81 | } 82 | t.Logf("inserted at slot %d: A=%d, B=%s", slot, i, s) 83 | slot, ok, err = rp.InsertAfter(slot) 84 | if err != nil { 85 | t.Fatal(err) 86 | } 87 | } 88 | 89 | if err := rp.Delete(2); err != nil { 90 | t.Fatal(err) 91 | } 92 | if err := rp.Delete(4); err != nil { 93 | t.Fatal(err) 94 | } 95 | 96 | slot, ok, err = rp.NextAfter(-1) 97 | if err != nil { 98 | t.Fatal(err) 99 | } 100 | got := make([]string, 0) 101 | for ok { 102 | a, err := rp.Int32(slot, "A") 103 | if err != nil { 104 | t.Fatal(err) 105 | } 106 | b, err := rp.Str(slot, "B") 107 | if err != nil { 108 | t.Fatal(err) 109 | } 110 | got = append(got, fmt.Sprintf("slot %d: A=%d, B=%s", slot, a, b)) 111 | slot, ok, err = rp.NextAfter(slot) 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | } 116 | expected := []string{ 117 | 
"slot 0: A=0, B=rec0", 118 | "slot 1: A=1, B=rec1", 119 | "slot 3: A=3, B=rec3", 120 | "slot 5: A=5, B=rec5", 121 | } 122 | if !reflect.DeepEqual(got, expected) { 123 | t.Errorf("expected %v, got %v", expected, got) 124 | } 125 | } 126 | 127 | func TestTableScan(t *testing.T) { 128 | t.Parallel() 129 | 130 | fm, err := file.NewManager(t.TempDir(), 128) 131 | if err != nil { 132 | t.Fatal(err) 133 | } 134 | lm, err := log.NewManager(fm, "logfile") 135 | if err != nil { 136 | t.Fatal(err) 137 | } 138 | ctx := context.Background() 139 | bm := buffer.NewManager(ctx, fm, lm, 8) 140 | 141 | tx, err := transaction.NewTransaction(ctx, bm, fm, lm) 142 | if err != nil { 143 | t.Fatal(err) 144 | } 145 | 146 | schema := schema.NewSchema() 147 | schema.AddInt32Field("A") 148 | schema.AddStrField("B", 9) 149 | 150 | layout := NewLayoutSchema(schema) 151 | 152 | ts, err := NewTableScan(tx, "T", layout) 153 | if err != nil { 154 | t.Fatal(err) 155 | } 156 | 157 | for i := 0; i < 10; i++ { 158 | if err := ts.Insert(); err != nil { 159 | t.Fatal(err) 160 | } 161 | if err := ts.SetInt32("A", int32(i)); err != nil { 162 | t.Fatal(err) 163 | } 164 | s := fmt.Sprintf("rec%d", i) 165 | if err := ts.SetStr("B", s); err != nil { 166 | t.Fatal(err) 167 | } 168 | t.Logf("inserted: %v, {%v, %v}", ts.RID(), i, s) 169 | } 170 | 171 | if err := ts.BeforeFirst(); err != nil { 172 | t.Fatal(err) 173 | } 174 | ok, err := ts.Next() 175 | if err != nil { 176 | t.Fatal(err) 177 | } 178 | for ok { 179 | a, err := ts.Int32("A") 180 | if err != nil { 181 | t.Fatal(err) 182 | } 183 | b, err := ts.Str("B") 184 | if err != nil { 185 | t.Fatal(err) 186 | } 187 | t.Logf("scanned: %v, {%v, %v}", ts.RID(), a, b) 188 | if a%2 == 0 { 189 | if err := ts.Delete(); err != nil { 190 | t.Fatal(err) 191 | } 192 | t.Logf("deleted: %v", ts.RID()) 193 | } 194 | ok, err = ts.Next() 195 | if err != nil { 196 | t.Fatal(err) 197 | } 198 | } 199 | 200 | if err := ts.BeforeFirst(); err != nil { 201 | t.Fatal(err) 202 | } 203 | 
ok, err = ts.Next() 204 | if err != nil { 205 | t.Fatal(err) 206 | } 207 | got := make([]string, 0) 208 | for ok { 209 | a, err := ts.Int32("A") 210 | if err != nil { 211 | t.Fatal(err) 212 | } 213 | b, err := ts.Str("B") 214 | if err != nil { 215 | t.Fatal(err) 216 | } 217 | got = append(got, fmt.Sprintf("%v, {%v, %v}", ts.RID(), a, b)) 218 | ok, err = ts.Next() 219 | if err != nil { 220 | t.Fatal(err) 221 | } 222 | } 223 | expected := []string{ 224 | "RID{blockNum=0, slot=1}, {1, rec1}", 225 | "RID{blockNum=0, slot=3}, {3, rec3}", 226 | "RID{blockNum=0, slot=5}, {5, rec5}", 227 | "RID{blockNum=1, slot=1}, {7, rec7}", 228 | "RID{blockNum=1, slot=3}, {9, rec9}", 229 | } 230 | if !reflect.DeepEqual(got, expected) { 231 | t.Errorf("expected %v, got %v", expected, got) 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /internal/record/schema/schema.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "cmp" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | type RID struct { 10 | blockNum int32 11 | slot int32 12 | } 13 | 14 | func NewRID(blockNum, slot int32) RID { 15 | return RID{blockNum: blockNum, slot: slot} 16 | } 17 | 18 | func (r RID) BlockNum() int32 { 19 | return r.blockNum 20 | } 21 | 22 | func (r RID) Slot() int32 { 23 | return r.slot 24 | } 25 | 26 | func (r RID) String() string { 27 | return fmt.Sprintf("RID{blockNum=%d, slot=%d}", r.blockNum, r.slot) 28 | } 29 | 30 | func (r RID) Equals(other RID) bool { 31 | return r.blockNum == other.blockNum && r.slot == other.slot 32 | } 33 | 34 | type FieldType int32 35 | 36 | const ( 37 | Integer32 FieldType = iota 38 | Varchar 39 | ) 40 | 41 | type Flag int32 42 | 43 | const ( 44 | Empty Flag = iota 45 | Used 46 | ) 47 | 48 | type Field struct { 49 | typ FieldType 50 | length int32 51 | } 52 | 53 | func NewField(typ FieldType, length int32) Field { 54 | return Field{typ: typ, length: length} 55 | } 56 | 
57 | func NewInt32Field() Field { 58 | return Field{typ: Integer32, length: 0} 59 | } 60 | 61 | func NewVarcharField(length int32) Field { 62 | return Field{typ: Varchar, length: length} 63 | } 64 | 65 | type Schema struct { 66 | fields []FieldName 67 | fieldsMap map[FieldName]Field 68 | } 69 | 70 | func NewSchema() Schema { 71 | return Schema{ 72 | fields: make([]FieldName, 0), 73 | fieldsMap: make(map[FieldName]Field), 74 | } 75 | } 76 | 77 | func (s *Schema) AddField(name FieldName, f Field) { 78 | s.fields = append(s.fields, name) 79 | s.fieldsMap[name] = f 80 | } 81 | 82 | func (s *Schema) AddInt32Field(name FieldName) { 83 | s.AddField(name, NewInt32Field()) 84 | } 85 | 86 | func (s *Schema) AddStrField(name FieldName, length int32) { 87 | s.AddField(name, NewVarcharField(length)) 88 | } 89 | 90 | func (s *Schema) Add(name FieldName, schema Schema) { 91 | typ := schema.Typ(name) 92 | length := schema.Length(name) 93 | s.AddField(name, Field{typ: typ, length: length}) 94 | } 95 | 96 | func (s *Schema) AddAll(schema Schema) { 97 | for _, field := range schema.fields { 98 | f, ok := schema.fieldsMap[field] 99 | if !ok { 100 | panic("field not found") 101 | } 102 | s.AddField(field, f) 103 | } 104 | } 105 | 106 | func (s *Schema) FieldNames() []FieldName { 107 | names := make([]FieldName, 0, len(s.fields)) 108 | names = append(names, s.fields...) 
109 | return names 110 | } 111 | 112 | func (s *Schema) HasField(name FieldName) bool { 113 | _, ok := s.fieldsMap[name] 114 | return ok 115 | } 116 | 117 | func (s *Schema) Typ(name FieldName) FieldType { 118 | return s.fieldsMap[name].typ 119 | } 120 | 121 | func (s *Schema) Length(name FieldName) int32 { 122 | return s.fieldsMap[name].length 123 | } 124 | 125 | type FieldName string 126 | 127 | func (f FieldName) Evaluate(v Valuable) (Constant, error) { 128 | return v.Val(f) 129 | } 130 | 131 | func (f FieldName) AppliesTo(s *Schema) bool { 132 | return s.HasField(f) 133 | } 134 | 135 | type Constant interface { 136 | fmt.Stringer 137 | Val() any 138 | HashCode() int 139 | Equals(Constant) bool 140 | Compare(Constant) int 141 | } 142 | 143 | type ConstantInt32 int32 144 | 145 | func (v ConstantInt32) String() string { 146 | return fmt.Sprintf("%d", v) 147 | } 148 | 149 | func (v ConstantInt32) Val() any { 150 | return int32(v) 151 | } 152 | 153 | func (v ConstantInt32) Evaluate(Valuable) (Constant, error) { 154 | return v, nil 155 | } 156 | 157 | func (v ConstantInt32) HashCode() int { 158 | return int(v) 159 | } 160 | 161 | func (v ConstantInt32) Equals(c Constant) bool { 162 | if c, ok := c.(ConstantInt32); ok { 163 | return v == c 164 | } 165 | return false 166 | } 167 | 168 | func (v ConstantInt32) Compare(c Constant) int { 169 | if c, ok := c.(ConstantInt32); ok { 170 | return cmp.Compare(int32(v), int32(c)) 171 | } 172 | panic("type mismatch") 173 | } 174 | 175 | func (v ConstantInt32) AppliesTo(_ *Schema) bool { 176 | return true 177 | } 178 | 179 | type ConstantStr string 180 | 181 | func (v ConstantStr) String() string { 182 | return string(v) 183 | } 184 | 185 | func (v ConstantStr) Val() any { 186 | return string(v) 187 | } 188 | 189 | func (v ConstantStr) Evaluate(Valuable) (Constant, error) { 190 | return v, nil 191 | } 192 | 193 | func (v ConstantStr) HashCode() int { 194 | h := 0 195 | for _, c := range v { 196 | h = h*31 + int(c) 197 | } 198 | 
return h 199 | } 200 | 201 | func (v ConstantStr) Equals(c Constant) bool { 202 | if c, ok := c.(ConstantStr); ok { 203 | return v == c 204 | } 205 | return false 206 | } 207 | 208 | func (v ConstantStr) Compare(c Constant) int { 209 | if c, ok := c.(ConstantStr); ok { 210 | return cmp.Compare(string(v), string(c)) 211 | } 212 | panic("type mismatch") 213 | } 214 | 215 | func (v ConstantStr) AppliesTo(_ *Schema) bool { 216 | return true 217 | } 218 | 219 | type Valuable interface { 220 | Val(fieldName FieldName) (Constant, error) 221 | } 222 | 223 | type Placeholder int 224 | 225 | func (p Placeholder) String() string { 226 | return fmt.Sprintf("$%d", p) 227 | } 228 | 229 | func (p Placeholder) Val() any { 230 | panic("don't use placeholder as value") 231 | } 232 | 233 | func (p Placeholder) Evaluate(v Valuable) (Constant, error) { 234 | return nil, fmt.Errorf("placeholder cannot be evaluated") 235 | } 236 | 237 | func (p Placeholder) HashCode() int { 238 | panic("don't use placeholder as value") 239 | } 240 | 241 | func (p Placeholder) Equals(Constant) bool { 242 | panic("don't use placeholder as value") 243 | } 244 | 245 | func (p Placeholder) Compare(Constant) int { 246 | panic("don't use placeholder as value") 247 | } 248 | 249 | func (p Placeholder) AppliesTo(_ *Schema) bool { 250 | return true 251 | } 252 | 253 | var ErrTypeAssertionFailed = errors.New("type assertion failed") 254 | -------------------------------------------------------------------------------- /internal/simpledb/simpledb.go: -------------------------------------------------------------------------------- 1 | package simpledb 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/abekoh/simple-db/internal/buffer" 8 | "github.com/abekoh/simple-db/internal/file" 9 | "github.com/abekoh/simple-db/internal/log" 10 | "github.com/abekoh/simple-db/internal/metadata" 11 | "github.com/abekoh/simple-db/internal/plan" 12 | "github.com/abekoh/simple-db/internal/statement" 13 | 
"github.com/abekoh/simple-db/internal/transaction" 14 | ) 15 | 16 | // Default storage parameters used by NewWithConfig. 17 | const ( 18 | defaultBlockSize = 400 19 | defaultBufferSize = 64 20 | ) 21 | 22 | // DB bundles the file, log, buffer, and metadata managers with the planner 23 | // and prepared-statement manager of one database instance. 24 | type DB struct { 25 | fileMgr *file.Manager 26 | bufMgr *buffer.Manager 27 | logMgr *log.Manager 28 | metadataMgr *metadata.Manager 29 | planner *plan.Planner 30 | stmtMgr *statement.Manager 31 | } 32 | 33 | // NewWithParams creates only the file, log, and buffer managers for dirname, 34 | // passing blockSize to the file manager and bufSize to the buffer manager. 35 | // The metadata manager, planner, and statement manager are left nil and no 36 | // recovery is run; use NewWithConfig (or New) to get a fully initialized DB. 37 | func NewWithParams(ctx context.Context, dirname string, blockSize int32, bufSize int) (*DB, error) { 38 | fm, err := file.NewManager(dirname, blockSize) 39 | if err != nil { 40 | return nil, err 41 | } 42 | const logFileName = "simpledb.log" 43 | lm, err := log.NewManager(fm, logFileName) 44 | if err != nil { 45 | return nil, err 46 | } 47 | bm := buffer.NewManager(ctx, fm, lm, bufSize) 48 | return &DB{ 49 | fileMgr: fm, 50 | bufMgr: bm, 51 | logMgr: lm, 52 | }, nil 53 | } 54 | 55 | // NewWithConfig opens the database in dirname with the default block and 56 | // buffer sizes. In a single transaction it runs recovery (unless the file 57 | // manager reports a new database), creates the metadata manager and the 58 | // planners, and commits before returning. A nil cfg is treated as &Config{}. 59 | func NewWithConfig(ctx context.Context, dirname string, cfg *Config) (*DB, error) { 60 | if cfg == nil { 61 | cfg = &Config{} 62 | } 63 | db, err := NewWithParams(ctx, dirname, defaultBlockSize, defaultBufferSize) 64 | if err != nil { 65 | return nil, fmt.Errorf("could not create SimpleDB: %w", err) 66 | } 67 | tx, err := transaction.NewTransaction(ctx, db.bufMgr, db.fileMgr, db.logMgr) 68 | if err != nil { 69 | return nil, fmt.Errorf("could not create SimpleDB: %w", err) 70 | } 71 | // NOTE(review): the error returns below leave tx neither committed nor 72 | // rolled back; confirm whether it must be released explicitly. 73 | isNew := db.fileMgr.IsNew() 74 | if !isNew { 75 | if err := tx.Recover(); err != nil { 76 | return nil, fmt.Errorf("could not recover: %w", err) 77 | } 78 | } 79 | metadataMgr, err := metadata.NewManager(isNew, tx, cfg.Metadata) 80 | if err != nil { 81 | return nil, fmt.Errorf("could not create SimpleDB: %w", err) 82 | } 83 | queryPlannerInitializer := plan.NewHeuristicQueryPlanner 84 | if cfg.Plan != nil && cfg.Plan.QueryPlannerInitializer != nil { 85 | queryPlannerInitializer = cfg.Plan.QueryPlannerInitializer 86 | } 87 | updatePlannerInitializer := plan.NewIndexUpdatePlanner 88 | if cfg.Plan != nil && cfg.Plan.UpdatePlannerInitializer != nil { 89 | updatePlannerInitializer = cfg.Plan.UpdatePlannerInitializer 90 | } 91 | db.metadataMgr = metadataMgr 92 | db.planner = plan.NewPlanner( 93 | queryPlannerInitializer(metadataMgr), 94 | updatePlannerInitializer(metadataMgr), 95 | metadataMgr, 96 | ) 97 | db.stmtMgr = statement.NewManager() 98 | if err := tx.Commit(); err != nil { 99 | return nil, fmt.Errorf("could not commit: %w", err) 100 | } 101 | return db, nil 102 | } 103 | 104 | // New opens the database in dirname with the default configuration. 105 | func New(ctx context.Context, dirname string) (*DB, error) { 106 | return NewWithConfig(ctx, dirname, nil) 107 | } 108 | 109 | // NewTx starts a new transaction on db's buffer, file, and log managers. 110 | func (db DB) NewTx(ctx context.Context) (*transaction.Transaction, error) { 111 | return transaction.NewTransaction(ctx, db.bufMgr, db.fileMgr, db.logMgr) 112 | } 113 | 114 | // MetadataMgr returns the metadata manager (nil for a DB from NewWithParams). 115 | func (db DB) MetadataMgr() *metadata.Manager { 116 | return db.metadataMgr 117 | } 118 | 119 | // Planner returns the planner built by NewWithConfig. 120 | func (db DB) Planner() *plan.Planner { 121 | return db.planner 122 | } 123 | 124 | // StmtMgr returns the manager of named prepared statements. 125 | func (db DB) StmtMgr() *statement.Manager { 126 | return db.stmtMgr 127 | } 128 | 129 | // Config customizes NewWithConfig. A nil Plan, or a nil initializer inside 130 | // it, selects the default planner; Metadata is passed to metadata.NewManager. 131 | type Config struct { 132 | Plan *plan.Config 133 | Metadata *metadata.Config 134 | } 135 | 
-------------------------------------------------------------------------------- /internal/statement/statement.go: -------------------------------------------------------------------------------- 1 | package statement 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/abekoh/simple-db/internal/record/schema" 7 | ) 8 | 9 | // Prepared is a statement that can be bound to parameter values via SwapParams. 10 | type Prepared interface { 11 | // Placeholders returns the field type of each placeholder, keyed by 12 | // placeholder number; findSchema resolves a table name to its schema. 13 | Placeholders(findSchema func(tableName string) (*schema.Schema, error)) map[int]schema.FieldType 14 | // SwapParams binds params (keyed by placeholder number) and returns the Bound statement. 15 | SwapParams(params map[int]schema.Constant) (Bound, error) 16 | } 17 | 18 | // Bound is the result of Prepared.SwapParams. 19 | type Bound interface { 20 | Bound() 21 | } 22 | 23 | // Manager stores Prepared statements by name. Its map is not guarded by a 24 | // lock, so a Manager is not safe for concurrent use. 25 | type Manager struct { 26 | statements map[string]Prepared 27 | } 28 | 29 | // NewManager returns an empty Manager. 30 | func NewManager() *Manager { 31 | return &Manager{ 32 | statements: make(map[string]Prepared), 33 | } 34 | } 35 | 36 | // Add stores prepared under name, replacing any statement already stored 37 | // under that name. 38 | func (m *Manager) Add(name string, prepared Prepared) { 39 | m.statements[name] = prepared 40 | } 41 | 42 | // Get returns the statement stored under name, or an error if there is none. 43 | func (m *Manager) Get(name string) (Prepared, error) { 44 | stmt, ok := m.statements[name] 45 | if !ok { 46 | return nil, fmt.Errorf("unknown statement: %s", name) 47 | } 48 | return stmt, nil 49 | } 50 | 
-------------------------------------------------------------------------------- /internal/testdata/create_indexes.sql: -------------------------------------------------------------------------------- 1 | CREATE INDEX departments_pkey ON departments (department_id); 2 | CREATE INDEX students_pkey ON students (student_id); 3 | CREATE INDEX courses_pkey ON courses (course_id); 4 | CREATE INDEX sections_pkey ON sections (section_id); 5 | CREATE INDEX students_major_department_id ON students (major_department_id); 6 | CREATE INDEX courses_course_department_id ON courses (course_department_id); 7 | CREATE INDEX sections_section_course_id ON sections (section_course_id); 8 | -------------------------------------------------------------------------------- /internal/testdata/create_tables.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE departments (department_id INT, department_name VARCHAR(10)); 2 | CREATE TABLE students (student_id INT, name VARCHAR(10), major_id INT, grad_year INT); 3 | CREATE TABLE courses(course_id INT, title VARCHAR(20), course_department_id INT); 4 | CREATE TABLE sections(section_id INT, section_course_id INT, professor VARCHAR(8), year_offered int) 5 | -------------------------------------------------------------------------------- /internal/testdata/example.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE departments (department_id INT, department_name VARCHAR(20)); 2 | CREATE TABLE students (student_id INT, name VARCHAR(10), major_id INT, grad_year INT); 3 | 4 | INSERT INTO students (student_id, name, major_id, grad_year) VALUES (1, 'Alice', 1, 2018); 5 | INSERT INTO students (student_id, name, major_id, grad_year) VALUES (2, 'Bob', 1, 2020); 6 | 7 | SELECT student_id, name FROM students; 8 | SELECT student_id, name FROM students WHERE student_id = 1; 9 | 10 | UPDATE students
SET name = 'Adam' WHERE student_id = 1; 11 | SELECT student_id, name FROM students; 12 | 13 | DELETE FROM students WHERE student_id = 1; 14 | SELECT student_id, name FROM students; 15 | 16 | BEGIN; 17 | UPDATE students SET name = 'BOB' WHERE student_id = 2; 18 | SELECT student_id, name FROM students; 19 | ROLLBACK; 20 | SELECT student_id, name FROM students; 21 | 22 | 23 | INSERT INTO departments (department_id, department_name) VALUES (1, 'Computer Science'); 24 | INSERT INTO departments (department_id, department_name) VALUES (2, 'Mathematics'); 25 | INSERT INTO students (student_id, name, major_id, grad_year) VALUES (1, 'Alice', 1, 2018); 26 | INSERT INTO students (student_id, name, major_id, grad_year) VALUES (2, 'Bob', 1, 2020); 27 | INSERT INTO students (student_id, name, major_id, grad_year) VALUES (3, 'Charlie', 1, 2007); 28 | INSERT INTO students (student_id, name, major_id, grad_year) VALUES (4, 'David', 2, 2019); 29 | INSERT INTO students (student_id, name, major_id, grad_year) VALUES (5, 'Eve', 2, 1999); 30 | 31 | SELECT name, department_name FROM students JOIN departments ON major_id = department_id; 32 | 33 | EXPLAIN SELECT department_name FROM departments WHERE department_id = 10; 34 | CREATE INDEX departments_pkey ON departments (department_id); 35 | EXPLAIN SELECT department_name FROM departments WHERE department_id = 10; 36 | EXPLAIN SELECT department_name, MIN(grad_year) AS min_grad_year FROM students JOIN departments ON major_id = department_id GROUP BY department_name ORDER BY department_name; 37 | -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables/field_catalog.tbl: -------------------------------------------------------------------------------- 1 |  table_catalog 2 | table_name  table_catalog slot_size( field_catalog 3 | table_name  field_catalog 4 | field_name ( field_catalogtypeL field_cataloglengthP field_catalogoffsetT view_catalog view_name  view_catalogview_defd( 
index_catalog 5 | index_name  index_catalog 6 | table_name ( index_catalog 7 | field_name L departments department_id departmentsdepartment_name 8 | students 9 | student_idstudentsname 10 | studentsmajor_department_idstudents grad_yearcourses course_idcoursestitlecoursescourse_department_id sections 11 | section_idsectionssection_course_idsections professor sections year_offered -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables/simpledb.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables/simpledb.log -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables/table_catalog.tbl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables/table_catalog.tbl -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables/view_catalog.tbl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_data/courses.tbl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_data/courses.tbl -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_data/departments.tbl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_data/departments.tbl -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_data/field_catalog.tbl: -------------------------------------------------------------------------------- 1 |  table_catalog 2 | table_name  table_catalog slot_size( field_catalog 3 | table_name  field_catalog 4 | field_name ( field_catalogtypeL field_cataloglengthP field_catalogoffsetT view_catalog view_name  view_catalogview_defd( index_catalog 5 | index_name  index_catalog 6 | table_name ( index_catalog 7 | field_name L departments department_id departmentsdepartment_name 8 | students 9 | student_idstudentsname 10 | studentsmajor_idstudents grad_yearcourses course_idcoursestitlecoursescourse_department_id sections 11 | section_idsectionssection_course_idsections professor sections year_offered -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_data/index_catalog.tbl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_data/sections.tbl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_data/sections.tbl -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_data/simpledb.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_data/simpledb.log 
-------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_data/students.tbl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_data/students.tbl -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_data/table_catalog.tbl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_data/table_catalog.tbl -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_data/view_catalog.tbl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/courses.tbl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/courses.tbl -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/courses_course_department_id_dir: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/courses_course_department_id_dir -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/courses_course_department_id_leaf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/courses_course_department_id_leaf -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/courses_pkey_dir: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/courses_pkey_dir -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/courses_pkey_leaf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/courses_pkey_leaf -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/departments.tbl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/departments.tbl -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/departments_pkey_dir: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/departments_pkey_dir -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/departments_pkey_leaf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/departments_pkey_leaf -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/field_catalog.tbl: -------------------------------------------------------------------------------- 1 |  table_catalog 2 | table_name  table_catalog slot_size( field_catalog 3 | table_name  field_catalog 4 | field_name ( field_catalogtypeL field_cataloglengthP field_catalogoffsetT view_catalog view_name  view_catalogview_defd( index_catalog 5 | index_name  index_catalog 6 | table_name ( index_catalog 7 | field_name L departments department_id departmentsdepartment_name 8 | students 9 | student_idstudentsname 10 | studentsmajor_idstudents grad_yearcourses course_idcoursestitlecoursescourse_department_id sections 11 | section_idsectionssection_course_idsections professor sections year_offered -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/index_catalog.tbl: -------------------------------------------------------------------------------- 1 | departments_pkey departments department_id students_pkeystudents 2 | student_id courses_pkeycourses course_id sections_pkeysections 3 | section_idstudents_major_department_idstudentsmajor_department_idcourses_course_department_idcoursescourse_department_idsections_section_course_idsectionssection_course_id -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/sections.tbl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/sections.tbl -------------------------------------------------------------------------------- 
/internal/testdata/snapshots/tables_indexes_data/sections_pkey_dir: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/sections_pkey_dir -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/sections_pkey_leaf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/sections_pkey_leaf -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/sections_section_course_id_dir: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/sections_section_course_id_dir -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/sections_section_course_id_leaf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/sections_section_course_id_leaf -------------------------------------------------------------------------------- /internal/testdata/snapshots/tables_indexes_data/simpledb.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/simpledb.log -------------------------------------------------------------------------------- 
/internal/testdata/snapshots/tables_indexes_data/students.tbl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/students.tbl
--------------------------------------------------------------------------------
/internal/testdata/snapshots/tables_indexes_data/students_pkey_dir:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/students_pkey_dir
--------------------------------------------------------------------------------
/internal/testdata/snapshots/tables_indexes_data/students_pkey_leaf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/students_pkey_leaf
--------------------------------------------------------------------------------
/internal/testdata/snapshots/tables_indexes_data/table_catalog.tbl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abekoh/simple-db/801b5bb6bff0e1e8670f03ee00edd81a57e1c09f/internal/testdata/snapshots/tables_indexes_data/table_catalog.tbl
--------------------------------------------------------------------------------
/internal/testdata/snapshots/tables_indexes_data/view_catalog.tbl:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/internal/testdata/testdata.go:
--------------------------------------------------------------------------------
1 | package testdata
2 |
3 | import (
4 | 	"bufio"
5 | 	"embed"
6 | 	"fmt"
7 | 	"io"
8 | 	"os"
9 | 	"path"
10 | 	"strings"
11 | )
12 |
13 | // embedFiles embeds the contents of this package directory, including the
14 | // *.sql fixtures and the snapshots tree read by the functions below.
15 | //
16 | //go:embed *
17 | var embedFiles embed.FS
18 |
19 | // SQLIterator returns an iterator, for use with range-over-func, that yields
20 | // the statements of the named embedded SQL files in order, one per line.
21 | //
22 | // Blank lines and "--" comment lines are skipped, so every statement must fit
23 | // on a single line. An error opening or reading a file is yielded with an
24 | // empty statement; iteration then moves on to the next file unless the
25 | // consumer stops.
26 | func SQLIterator(filenames ...string) func(func(string, error) bool) {
27 | 	return func(yield func(string, error) bool) {
28 | 		for _, filename := range filenames {
29 | 			if !yieldSQLFile(filename, yield) {
30 | 				return
31 | 			}
32 | 		}
33 | 	}
34 | }
35 |
36 | // yieldSQLFile feeds the statements of one embedded SQL file to yield and
37 | // reports whether iteration should continue.
38 | func yieldSQLFile(filename string, yield func(string, error) bool) bool {
39 | 	f, err := embedFiles.Open(filename)
40 | 	if err != nil {
41 | 		return yield("", err)
42 | 	}
43 | 	defer f.Close() // read-only embedded file; nothing to flush
44 |
45 | 	scanner := bufio.NewScanner(f)
46 | 	for scanner.Scan() {
47 | 		text := scanner.Text()
48 | 		trimmed := strings.TrimSpace(text)
49 | 		if trimmed == "" || strings.HasPrefix(trimmed, "--") {
50 | 			continue
51 | 		}
52 | 		if !yield(text, nil) {
53 | 			return false
54 | 		}
55 | 	}
56 | 	// Without this check a read error (for example a line longer than the
57 | 	// scanner's token limit) would silently end the file early.
58 | 	if err := scanner.Err(); err != nil {
59 | 		return yield("", fmt.Errorf("scan %s: %w", filename, err))
60 | 	}
61 | 	return true
62 | }
63 |
64 | // CopySnapshotData copies the regular files of the embedded directory
65 | // snapshots/<srcDirname> into destDirPath, which must already exist.
66 | // Subdirectories are skipped and existing destination files are truncated.
67 | func CopySnapshotData(srcDirname, destDirPath string) error {
68 | 	srcDir := path.Join("snapshots", srcDirname)
69 | 	entries, err := embedFiles.ReadDir(srcDir)
70 | 	if err != nil {
71 | 		return fmt.Errorf("read dir: %w", err)
72 | 	}
73 | 	for _, entry := range entries {
74 | 		if entry.IsDir() {
75 | 			continue
76 | 		}
77 | 		srcPath := path.Join(srcDir, entry.Name())
78 | 		destPath := path.Join(destDirPath, entry.Name())
79 | 		if err := copyFile(srcPath, destPath); err != nil {
80 | 			return fmt.Errorf("copy file %s: %w", srcPath, err)
81 | 		}
82 | 	}
83 | 	return nil
84 | }
85 |
86 | // copyFile copies the embedded file src to the OS path dst, creating or
87 | // truncating dst. The error from closing dst is returned as well, since a
88 | // failed Close on a written file can mean the data was not persisted.
89 | func copyFile(src, dst string) (err error) {
90 | 	sourceFile, err := embedFiles.Open(src)
91 | 	if err != nil {
92 | 		return err
93 | 	}
94 | 	defer sourceFile.Close()
95 |
96 | 	destinationFile, err := os.Create(dst)
97 | 	if err != nil {
98 | 		return err
99 | 	}
100 | 	defer func() {
101 | 		if closeErr := destinationFile.Close(); closeErr != nil && err == nil {
102 | 			err = closeErr
103 | 		}
104 | 	}()
105 |
106 | 	_, err = io.Copy(destinationFile, sourceFile)
107 | 	return err
108 | }
109 |
--------------------------------------------------------------------------------
/internal/testdata/testdata_test.go:
--------------------------------------------------------------------------------
1 | package testdata_test
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"os"
7 | 	"testing"
8 |
9 | 	"github.com/abekoh/simple-db/internal/simpledb"
10 | 	"github.com/abekoh/simple-db/internal/testdata"
11 | 	"github.com/abekoh/simple-db/internal/transaction"
12 | 	"github.com/brianvoe/gofakeit/v7"
13 | )
14 |
15 | // TestCreateTestdata regenerates the SQL fixtures create_tables.sql,
16 | // create_indexes.sql and insert_data.sql in the working directory from a
17 | // fixed gofakeit seed. It overwrites committed files, so, like the
18 | // TestCreateSnapshots subtests, it is skipped by default; after running it,
19 | // rebuild the snapshots, which are created by replaying these scripts.
20 | func TestCreateTestdata(t *testing.T) {
21 | 	t.Skip("regenerates committed SQL fixtures; remove this skip to run it manually")
22 |
23 | 	// writer creates (or truncates) filename and returns a line writer and a
24 | 	// closer; both fail the test on I/O errors.
25 | 	writer := func(filename string) (write func(string), closeFile func()) {
26 | 		f, err := os.Create(filename)
27 | 		if err != nil {
28 | 			t.Fatal(err)
29 | 		}
30 | 		write = func(s string) {
31 | 			if _, err := f.WriteString(s + "\n"); err != nil {
32 | 				t.Fatal(err)
33 | 			}
34 | 		}
35 | 		closeFile = func() {
36 | 			if err := f.Close(); err != nil {
37 | 				t.Fatal(err)
38 | 			}
39 | 		}
40 | 		return write, closeFile
41 | 	}
42 |
43 | 	w, c := writer("create_tables.sql")
44 | 	defer c()
45 | 	w("CREATE TABLE departments (department_id INT, department_name VARCHAR(10));")
46 | 	w("CREATE TABLE students (student_id INT, name VARCHAR(10), major_id INT, grad_year INT);")
47 | 	w("CREATE TABLE courses(course_id INT, title VARCHAR(20), course_department_id INT);")
48 | 	w("CREATE TABLE sections(section_id INT, section_course_id INT, professor VARCHAR(8), year_offered INT);")
49 |
50 | 	w, c = writer("create_indexes.sql")
51 | 	defer c()
52 | 	w("CREATE INDEX departments_pkey ON departments (department_id);")
53 | 	w("CREATE INDEX students_pkey ON students (student_id);")
54 | 	w("CREATE INDEX courses_pkey ON courses (course_id);")
55 | 	w("CREATE INDEX sections_pkey ON sections (section_id);")
56 | 	// students has no major_department_id column; its foreign key is major_id.
57 | 	w("CREATE INDEX students_major_id ON students (major_id);")
58 | 	w("CREATE INDEX courses_course_department_id ON courses (course_department_id);")
59 | 	w("CREATE INDEX sections_section_course_id ON sections (section_course_id);")
60 |
61 | 	faker := gofakeit.New(523207)
62 |
63 | 	w, c = writer("insert_data.sql")
64 | 	defer c()
65 | 	w("-- departments")
66 | 	const (
67 | 		departmentOffset = 100000
68 | 		departmentLength = 100
69 | 	)
70 | 	// Department IDs are departmentOffset+1 .. departmentOffset+departmentLength.
71 | 	for i := 1; i <= departmentLength; i++ {
72 | 		w(fmt.Sprintf("INSERT INTO departments (department_id, department_name) VALUES (%d, '%s');", departmentOffset+i, faker.Language()))
73 | 	}
74 |
75 | 	w("-- students")
76 | 	const (
77 | 		studentOffset = 200000
78 | 		studentLength = 10000
79 | 	)
80 | 	for i := 1; i <= studentLength; i++ {
81 | 		// major_id is drawn (bounds inclusive) from the department IDs generated above.
82 | 		w(fmt.Sprintf("INSERT INTO students (student_id, name, major_id, grad_year) VALUES (%d, '%s', %d, %d);", studentOffset+i, faker.FirstName(), faker.Number(departmentOffset+1, departmentOffset+departmentLength), faker.Year()))
83 | 	}
84 |
85 | 	w("-- courses")
86 | 	const (
87 | 		courseOffset = 300000
88 | 	)
89 | 	courseTitles := []string{"Intro 1", "Intro 2", "Intro 3", "Advanced 1", "Advanced 2"}
90 | 	// courseCount is the number of the next course to insert; once the loop
91 | 	// finishes, course IDs courseOffset+1 .. courseOffset+courseCount-1 exist.
92 | 	courseCount := 1
93 | 	for i := 1; i <= departmentLength; i++ {
94 | 		courseN := faker.Number(0, len(courseTitles))
95 | 		for j := 0; j < courseN; j++ {
96 | 			w(fmt.Sprintf("INSERT INTO courses (course_id, title, course_department_id) VALUES (%d, '%s', %d);", courseOffset+courseCount, courseTitles[j], departmentOffset+i))
97 | 			courseCount++
98 | 		}
99 | 	}
100 |
101 | 	w("-- sections")
102 | 	const (
103 | 		sectionOffset = 400000
104 | 	)
105 | 	sectionCount := 1
106 | 	// One section per existing course; `i < courseCount` keeps section_course_id
107 | 	// from pointing at courseOffset+courseCount, which was never inserted.
108 | 	for i := 1; i < courseCount; i++ {
109 | 		w(fmt.Sprintf("INSERT INTO sections (section_id, section_course_id, professor, year_offered) VALUES (%d, %d, '%s', %d);", sectionOffset+sectionCount, courseOffset+i, faker.FirstName(), faker.Year()))
110 | 		sectionCount++
111 | 	}
112 | }
113 |
114 | // TestCreateSnapshots rebuilds the database directories under snapshots/ by
115 | // replaying the SQL fixtures through the planner. Every subtest is skipped
116 | // by default because it deletes and rewrites committed snapshot files; run
117 | // it after regenerating the fixtures with TestCreateTestdata.
118 | func TestCreateSnapshots(t *testing.T) {
119 | 	createSnapshot := func(t *testing.T, dirname string, sqlFilenames ...string) {
120 | 		t.Helper()
121 |
122 | 		if err := os.RemoveAll(dirname); err != nil {
123 | 			t.Fatal(err)
124 | 		}
125 | 		if err := os.MkdirAll(dirname, 0755); err != nil {
126 | 			t.Fatal(err)
127 | 		}
128 | 		ctx := context.Background()
129 | 		db, err := simpledb.New(ctx, dirname)
130 | 		if err != nil {
131 | 			t.Fatal(err)
132 | 		}
133 | 		tx, err := db.NewTx(ctx)
134 | 		if err != nil {
135 | 			t.Fatal(err)
136 | 		}
137 | 		count := 0
138 | 		for sql, err := range testdata.SQLIterator(sqlFilenames...) {
139 | 			if err != nil {
140 | 				t.Fatal(err)
141 | 			}
142 | 			t.Logf("execute %s", sql)
143 | 			if _, err := db.Planner().Execute(sql, tx); err != nil {
144 | 				t.Fatalf("execute %q: %v", sql, err)
145 | 			}
146 | 			// Commit every 100 statements and continue in a fresh transaction,
147 | 			// presumably to bound the work held by any single transaction.
148 | 			if count++; count%100 == 0 {
149 | 				if err := tx.Commit(); err != nil {
150 | 					t.Fatal(err)
151 | 				}
152 | 				next, err := db.NewTx(ctx)
153 | 				if err != nil {
154 | 					t.Fatal(err)
155 | 				}
156 | 				tx = next
157 | 			}
158 | 		}
159 | 		if err := tx.Commit(); err != nil {
160 | 			t.Fatal(err)
161 | 		}
162 | 	}
163 | 	t.Run("snapshots/tables", func(t *testing.T) {
164 | 		t.Skip()
165 | 		transaction.CleanupLockTable(t)
166 | 		createSnapshot(t, "snapshots/tables", "create_tables.sql")
167 | 	})
168 | 	t.Run("snapshots/tables_data", func(t *testing.T) {
169 | 		t.Skip()
170 | 		transaction.CleanupLockTable(t)
171 | 		createSnapshot(t, "snapshots/tables_data", "create_tables.sql", "insert_data.sql")
172 | 	})
173 | 	t.Run("snapshots/tables_indexes_data", func(t *testing.T) {
174 | 		t.Skip()
175 | 		transaction.CleanupLockTable(t)
176 | 		createSnapshot(t, "snapshots/tables_indexes_data", "create_tables.sql", "create_indexes.sql", "insert_data.sql")
177 | 	})
178 | }
179 |
--------------------------------------------------------------------------------
/internal/transaction/transaction_test.go:
--------------------------------------------------------------------------------
1 | package transaction
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"log/slog"
7 | 	"strings"
8 | 	"testing"
9 | 	"time"
10 |
11 | 	"golang.org/x/sync/errgroup"
12 |
13 | 	"github.com/abekoh/simple-db/internal/buffer"
14 | 	"github.com/abekoh/simple-db/internal/file"
15 | 	"github.com/abekoh/simple-db/internal/log"
16 | )
17 |
18 | // TestTransaction exercises the transaction layer end to end:
19 | //   - commit/rollback visibility on one block across sequential transactions,
20 | //   - three concurrent transactions requesting shared and exclusive locks on
21 | //     two blocks, all of which are expected to commit, and
22 | //   - a transaction reading a block it has already modified.
23 | func TestTransaction(t *testing.T) {
24 | 	t.Parallel()
25 | 	t.Run("Transaction", func(t *testing.T) {
26 | 		CleanupLockTable(t)
27 | 		must := func(t *testing.T, err error) {
28 | 			t.Helper()
29 | 			if err != nil {
30 | 				t.Fatal(err)
31 | 			}
32 | 		}
33 |
34 | 		slog.SetLogLoggerLevel(slog.LevelDebug)
35 | 		fm, err := file.NewManager(t.TempDir(), 128)
36 | 		if err != nil {
37 | 			t.Fatal(err)
38 | 		}
39 | 		lm, err := log.NewManager(fm, "logfile")
40 | 		if err != nil {
41 | 			t.Fatal(err)
42 | 		}
43 | 		ctx := context.Background()
44 | 		bm := buffer.NewManager(ctx, fm, lm, 8)
45 |
46 | 		tx1, err := NewTransaction(ctx, bm, fm, lm)
47 | 		must(t, err)
48 | 		blockID := file.NewBlockID("testfile", 1)
49 | 		_, err = tx1.Pin(blockID)
50 | 		must(t, err)
51 | 		must(t, tx1.SetInt32(blockID, 80, 1, false))
52 | 		must(t, tx1.SetStr(blockID, 40, "one", false))
53 | 		must(t, tx1.Commit())
54 |
55 | 		tx2, err := NewTransaction(ctx, bm, fm, lm)
56 | 		must(t, err)
57 | 		_, err = tx2.Pin(blockID)
58 | 		must(t, err)
59 | 		beforeTx2IntVal, err := tx2.Int32(blockID, 80)
60 | 		must(t, err)
61 | 		beforeTx2StrVal, err := tx2.Str(blockID, 40)
62 | 		must(t, err)
63 | 		if beforeTx2IntVal != 1 {
64 | 			t.Errorf("expected 1, got %d", beforeTx2IntVal)
65 | 		}
66 | 		if beforeTx2StrVal != "one" {
67 | 			t.Errorf("expected one, got %s", beforeTx2StrVal)
68 | 		}
69 | 		must(t, tx2.SetInt32(blockID, 80, beforeTx2IntVal+1, true))
70 | 		must(t, tx2.SetStr(blockID, 40, beforeTx2StrVal+"!", true))
71 | 		must(t, tx2.Commit())
72 |
73 | 		tx3, err := NewTransaction(ctx, bm, fm, lm)
74 | 		must(t, err)
75 | 		_, err = tx3.Pin(blockID)
76 | 		must(t, err)
77 | 		beforeTx3IntVal, err := tx3.Int32(blockID, 80)
78 | 		must(t, err)
79 | 		beforeTx3StrVal, err := tx3.Str(blockID, 40)
80 | 		must(t, err)
81 | 		if beforeTx3IntVal != 2 {
82 | 			t.Errorf("expected 2, got %d", beforeTx3IntVal)
83 | 		}
84 | 		if beforeTx3StrVal != "one!" {
85 | 			t.Errorf("expected one!, got %s", beforeTx3StrVal)
86 | 		}
87 | 		must(t, tx3.SetInt32(blockID, 80, 9999, true))
88 | 		must(t, tx3.SetStr(blockID, 40, "two", true))
89 | 		must(t, tx3.Rollback())
90 |
91 | 		tx4, err := NewTransaction(ctx, bm, fm, lm)
92 | 		must(t, err)
93 | 		_, err = tx4.Pin(blockID)
94 | 		must(t, err)
95 | 		beforeTx4IntVal, err := tx4.Int32(blockID, 80)
96 | 		must(t, err)
97 | 		beforeTx4StrVal, err := tx4.Str(blockID, 40)
98 | 		must(t, err)
99 | 		if beforeTx4IntVal != 2 {
100 | 			t.Errorf("expected 2, got %d", beforeTx4IntVal)
101 | 		}
102 | 		if beforeTx4StrVal != "one!" {
103 | 			t.Errorf("expected one!, got %s", beforeTx4StrVal)
104 | 		}
105 | 		must(t, tx4.Commit())
106 |
107 | 		records := make([]string, 0)
108 | 		for raw := range lm.Iterator() {
109 | 			r := NewLogRecord(raw)
110 | 			records = append(records, r.String())
111 | 		}
112 | 		// NOTE(review): the exact expected record sequence is not pinned down
113 | 		// yet, so only require a non-empty log and print it for inspection.
114 | 		// TODO: replace with an exact comparison of records once the expected
115 | 		// String() output of each log record is settled.
116 | 		if len(records) == 0 {
117 | 			t.Error("expected log records, got none")
118 | 		}
119 | 		t.Logf("log records: %q", records)
120 | 	})
121 | 	t.Run("Concurrency", func(t *testing.T) {
122 | 		CleanupLockTable(t)
123 | 		slog.SetLogLoggerLevel(slog.LevelDebug)
124 | 		fm, err := file.NewManager(t.TempDir(), 128)
125 | 		if err != nil {
126 | 			t.Fatal(err)
127 | 		}
128 | 		lm, err := log.NewManager(fm, "logfile")
129 | 		if err != nil {
130 | 			t.Fatal(err)
131 | 		}
132 | 		ctx := context.Background()
133 | 		bm := buffer.NewManager(ctx, fm, lm, 8)
134 |
135 | 		var g errgroup.Group
136 | 		g.Go(func() error {
137 | 			txA, err := NewTransaction(ctx, bm, fm, lm)
138 | 			if err != nil {
139 | 				return fmt.Errorf("failed txA: %w", err)
140 | 			}
141 | 			blockID1 := file.NewBlockID("testfile", 1)
142 | 			blockID2 := file.NewBlockID("testfile", 2)
143 | 			_, err = txA.Pin(blockID1)
144 | 			if err != nil {
145 | 				return fmt.Errorf("failed txA: %w", err)
146 | 			}
147 | 			_, err = txA.Pin(blockID2)
148 | 			if err != nil {
149 | 				return fmt.Errorf("failed txA: %w", err)
150 | 			}
151 | 			t.Logf("txA: request sLock %s", blockID1)
152 | 			_, err = txA.Int32(blockID1, 0)
153 | 			if err != nil {
154 | 				return fmt.Errorf("failed txA: %w", err)
155 | 			}
156 | 			t.Logf("txA: receive sLock %s", blockID1)
157 | 			time.Sleep(1 * time.Second)
158 | 			t.Logf("txA: request sLock %s", blockID2)
159 | 			_, err = txA.Int32(blockID2, 0)
160 | 			if err != nil {
161 | 				return fmt.Errorf("failed txA: %w", err)
162 | 			}
163 | 			t.Logf("txA: receive sLock %s", blockID2)
164 | 			if err := txA.Commit(); err != nil {
165 | 				return fmt.Errorf("failed txA: %w", err)
166 | 			}
167 | 			t.Log("txA: commit")
168 | 			return nil
169 | 		})
170 | 		g.Go(func() error {
171 | 			txB, err := NewTransaction(ctx, bm, fm, lm)
172 | 			if err != nil {
173 | 				return fmt.Errorf("failed txB: %w", err)
174 | 			}
175 | 			blockID1 := file.NewBlockID("testfile", 1)
176 | 			blockID2 := file.NewBlockID("testfile", 2)
177 | 			_, err = txB.Pin(blockID1)
178 | 			if err != nil {
179 | 				return fmt.Errorf("failed txB: %w", err)
180 | 			}
181 | 			_, err = txB.Pin(blockID2)
182 | 			if err != nil {
183 | 				return fmt.Errorf("failed txB: %w", err)
184 | 			}
185 | 			t.Logf("txB: request xLock %s", blockID2)
186 | 			if err := txB.SetInt32(blockID2, 0, 0, false); err != nil {
187 | 				return fmt.Errorf("failed txB: %w", err)
188 | 			}
189 | 			t.Logf("txB: receive xLock %s", blockID2)
190 | 			time.Sleep(1 * time.Second)
191 | 			t.Logf("txB: request sLock %s", blockID1)
192 | 			_, err = txB.Int32(blockID1, 0)
193 | 			if err != nil {
194 | 				return fmt.Errorf("failed txB: %w", err)
195 | 			}
196 | 			t.Logf("txB: receive sLock %s", blockID1)
197 | 			if err := txB.Commit(); err != nil {
198 | 				return fmt.Errorf("failed txB: %w", err)
199 | 			}
200 | 			t.Log("txB: commit")
201 | 			return nil
202 | 		})
203 | 		g.Go(func() error {
204 | 			txC, err := NewTransaction(ctx, bm, fm, lm)
205 | 			if err != nil {
206 | 				return fmt.Errorf("failed txC: %w", err)
207 | 			}
208 | 			blockID1 := file.NewBlockID("testfile", 1)
209 | 			blockID2 := file.NewBlockID("testfile", 2)
210 | 			_, err = txC.Pin(blockID1)
211 | 			if err != nil {
212 | 				return fmt.Errorf("failed txC: %w", err)
213 | 			}
214 | 			_, err = txC.Pin(blockID2)
215 | 			if err != nil {
216 | 				return fmt.Errorf("failed txC: %w", err)
217 | 			}
218 | 			time.Sleep(500 * time.Millisecond)
219 | 			t.Logf("txC: request xLock %s", blockID1)
220 | 			if err := txC.SetInt32(blockID1, 0, 0, false); err != nil {
221 | 				return fmt.Errorf("failed txC: %w", err)
222 | 			}
223 | 			t.Logf("txC: receive xLock %s", blockID1)
224 | 			time.Sleep(1 * time.Second)
225 | 			t.Logf("txC: request sLock %s", blockID2)
226 | 			_, err = txC.Int32(blockID2, 0)
227 | 			if err != nil {
228 | 				return fmt.Errorf("failed txC: %w", err)
229 | 			}
230 | 			t.Logf("txC: receive sLock %s", blockID2)
231 | 			if err := txC.Commit(); err != nil {
232 | 				return fmt.Errorf("failed txC: %w", err)
233 | 			}
234 | 			t.Log("txC: commit")
235 | 			return nil
236 | 		})
237 | 		if err := g.Wait(); err != nil {
238 | 			t.Fatal(err)
239 | 		}
240 | 	})
241 | 	t.Run("Concurrency sLock after xLock", func(t *testing.T) {
242 | 		CleanupLockTable(t)
243 | 		slog.SetLogLoggerLevel(slog.LevelDebug)
244 | 		fm, err := file.NewManager(t.TempDir(), 128)
245 | 		if err != nil {
246 | 			t.Fatal(err)
247 | 		}
248 | 		lm, err := log.NewManager(fm, "logfile")
249 | 		if err != nil {
250 | 			t.Fatal(err)
251 | 		}
252 | 		ctx := context.Background()
253 | 		bm := buffer.NewManager(ctx, fm, lm, 8)
254 |
255 | 		tx, err := NewTransaction(ctx, bm, fm, lm)
256 | 		if err != nil {
257 | 			t.Fatal(err)
258 | 		}
259 | 		blockID1 := file.NewBlockID("testfile", 1)
260 | 		_, err = tx.Pin(blockID1)
261 | 		if err != nil {
262 | 			t.Fatal(err)
263 | 		}
264 | 		err = tx.SetInt32(blockID1, 0, 0, true)
265 | 		if err != nil {
266 | 			t.Fatal(err)
267 | 		}
268 | 		_, err = tx.Int32(blockID1, 0)
269 | 		if err != nil {
270 | 			t.Fatal(err)
271 | 		}
272 | 	})
273 | }
274 |
275 | // TestTransaction_Recover commits initial values from two transactions,
276 | // overwrites them from two more and flushes, rolls one of those back, then
277 | // reopens the file and log managers on the same directory and checks that
278 | // Recover undoes the remaining uncommitted transaction's changes.
279 | func TestTransaction_Recover(t *testing.T) {
280 | 	CleanupLockTable(t)
281 | 	slog.SetLogLoggerLevel(slog.LevelDebug)
282 | 	tmpDir := t.TempDir()
283 | 	blockID0 := file.NewBlockID("testfile", 0)
284 | 	blockID1 := file.NewBlockID("testfile", 1)
285 | 	assertValues := func(t *testing.T, expected string, fm *file.Manager) {
286 | 		t.Helper()
287 | 		p0 := file.NewPage(fm.BlockSize())
288 | 		p1 := file.NewPage(fm.BlockSize())
289 | 		if err := fm.Read(blockID0, p0); err != nil {
290 | 			t.Fatal(err)
291 | 		}
292 | 		if err := fm.Read(blockID1, p1); err != nil {
293 | 			t.Fatal(err)
294 | 		}
295 | 		pos := int32(0)
296 | 		var sb strings.Builder
297 | 		for i := 0; i < 6; i++ {
298 | 			sb.WriteString(fmt.Sprintf("%d ", p0.Int32(pos)))
299 | 			sb.WriteString(fmt.Sprintf("%d ", p1.Int32(pos)))
300 | 			pos += 4
301 | 		}
302 | 		sb.WriteString(p0.Str(30) + " ")
303 | 		sb.WriteString(p1.Str(30) + " ")
304 | 		if sb.String() != expected {
305 | 			t.Errorf("expected %s, got %s", expected, sb.String())
306 | 		}
307 | 	}
308 |
309 | 	fm1, err := file.NewManager(tmpDir, 400)
310 | 	if err != nil {
311 | 		t.Fatal(err)
312 | 	}
313 | 	lm1, err := log.NewManager(fm1, "logfile")
314 | 	if err != nil {
315 | 		t.Fatal(err)
316 | 	}
317 | 	ctx1, cancel := context.WithCancel(context.Background())
318 | 	bm1 := buffer.NewManager(ctx1, fm1, lm1, 8)
319 |
320 | 	tx1, err := NewTransaction(ctx1, bm1, fm1, lm1)
321 | 	if err != nil {
322 | 		t.Fatal(err)
323 | 	}
324 | 	tx2, err := NewTransaction(ctx1, bm1, fm1, lm1)
325 | 	if err != nil {
326 | 		t.Fatal(err)
327 | 	}
328 |
329 | 	_, err = tx1.Pin(blockID0)
330 | 	if err != nil {
331 | 		t.Fatal(err)
332 | 	}
333 | 	_, err = tx2.Pin(blockID1)
334 | 	if err != nil {
335 | 		t.Fatal(err)
336 | 	}
337 |
338 | 	// initialize
339 | 	pos := int32(0)
340 | 	for i := 0; i < 6; i++ {
341 | 		if err := tx1.SetInt32(blockID0, pos, pos, false); err != nil {
342 | 			t.Fatal(err)
343 | 		}
344 | 		if err := tx2.SetInt32(blockID1, pos, pos, false); err != nil {
345 | 			t.Fatal(err)
346 | 		}
347 | 		pos += 4
348 | 	}
349 | 	if err := tx1.SetStr(blockID0, 30, "abc", false); err != nil {
350 | 		t.Fatal(err)
351 | 	}
352 | 	if err := tx2.SetStr(blockID1, 30, "def", false); err != nil {
353 | 		t.Fatal(err)
354 | 	}
355 | 	if err := tx1.Commit(); err != nil {
356 | 		t.Fatal(err)
357 | 	}
358 | 	if err := tx2.Commit(); err != nil {
359 | 		t.Fatal(err)
360 | 	}
361 | 	assertValues(t, "0 0 4 4 8 8 12 12 16 16 20 20 abc def ", fm1)
362 |
363 | 	// modify
364 | 	tx3, err := NewTransaction(ctx1, bm1, fm1, lm1)
365 | 	if err != nil {
366 | 		t.Fatal(err)
367 | 	}
368 | 	tx4, err := NewTransaction(ctx1, bm1, fm1, lm1)
369 | 	if err != nil {
370 | 		t.Fatal(err)
371 | 	}
372 | 	_, err = tx3.Pin(blockID0)
373 | 	if err != nil {
374 | 		t.Fatal(err)
375 | 	}
376 | 	_, err = tx4.Pin(blockID1)
377 | 	if err != nil {
378 | 		t.Fatal(err)
379 | 	}
380 | 	pos = int32(0)
381 | 	for i := 0; i < 6; i++ {
382 | 		if err := tx3.SetInt32(blockID0, pos, pos+100, true); err != nil {
383 | 			t.Fatal(err)
384 | 		}
385 | 		if err := tx4.SetInt32(blockID1, pos, pos+100, true); err != nil {
386 | 			t.Fatal(err)
387 | 		}
388 | 		pos += 4
389 | 	}
390 | 	if err := tx3.SetStr(blockID0, 30, "uvw", true); err != nil {
391 | 		t.Fatal(err)
392 | 	}
393 | 	if err := tx4.SetStr(blockID1, 30, "xyz", true); err != nil {
394 | 		t.Fatal(err)
395 | 	}
396 | 	if err := bm1.FlushAll(3); err != nil {
397 | 		t.Fatal(err)
398 | 	}
399 | 	if err := bm1.FlushAll(4); err != nil {
400 | 		t.Fatal(err)
401 | 	}
402 | 	assertValues(t, "100 100 104 104 108 108 112 112 116 116 120 120 uvw xyz ", fm1)
403 |
404 | 	// rollback tx3
405 | 	if err := tx3.Rollback(); err != nil {
406 | 		t.Fatal(err)
407 | 	}
408 | 	assertValues(t, "0 100 4 104 8 108 12 112 16 116 20 120 abc xyz ", fm1)
409 |
410 | 	cancel()
411 |
412 | 	fm2, err := file.NewManager(tmpDir, 400)
413 | 	if err != nil {
414 | 		t.Fatal(err)
415 | 	}
416 | 	lm2, err := log.NewManager(fm2, "logfile")
417 | 	if err != nil {
418 | 		t.Fatal(err)
419 | 	}
420 | 	ctx2 := context.Background()
421 | 	bm2 := buffer.NewManager(ctx2, fm2, lm2, 8)
422 |
423 | 	// recover (rollback tx4)
424 | 	tx5, err := NewTransaction(ctx2, bm2, fm2, lm2)
425 | 	if err != nil {
426 | 		t.Fatal(err)
427 | 	}
428 | 	if err := tx5.Recover(); err != nil {
429 | 		t.Fatal(err)
430 | 	}
431 | 	assertValues(t, "0 0 4 4 8 8 12 12 16 16 20 20 abc def ", fm2)
432 | }
433 |
--------------------------------------------------------------------------------