├── .github ├── dependabot.yml └── workflows │ ├── actionlint.yml │ ├── test-and-build.yml │ └── two-step-pr-approval.yml ├── .gitignore ├── .go-version ├── .golangci.yml ├── CODEOWNERS ├── LICENSE ├── README.md ├── changes.go ├── filter.go ├── filter_test.go ├── go.mod ├── go.sum ├── index.go ├── index_test.go ├── integ_test.go ├── isolation_test.go ├── memdb.go ├── memdb_test.go ├── schema.go ├── schema_test.go ├── txn.go ├── txn_test.go ├── watch-gen └── main.go ├── watch.go ├── watch_few.go └── watch_test.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | 9 | - package-ecosystem: "gomod" 10 | directory: "/" 11 | schedule: 12 | interval: "weekly" 13 | -------------------------------------------------------------------------------- /.github/workflows/actionlint.yml: -------------------------------------------------------------------------------- 1 | name: Lint GitHub Actions Workflows 2 | on: 3 | push: 4 | paths: 5 | - .github/** 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | actionlint: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 15 | - name: "Check workflow files" 16 | uses: docker://docker.mirror.hashicorp.services/rhysd/actionlint:latest 17 | with: 18 | args: -color 19 | -------------------------------------------------------------------------------- /.github/workflows/test-and-build.yml: -------------------------------------------------------------------------------- 1 | name: Test and Build 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | lint: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout Code 15 | uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 16 | - name: Setup Go 17 | uses: 
actions/setup-go@4d34df0c2316fe8122ab82dc22947d607c0c91f9 # v4.0.0 18 | with: 19 | go-version-file: go.mod 20 | - name: Download go modules 21 | run: go mod download 22 | - name: Check Formatting 23 | run: |- 24 | files=$(go fmt ./...) 25 | if [ -n "$files" ]; then 26 | echo "The following file(s) do not conform to go fmt:" 27 | echo "$files" 28 | exit 1 29 | fi 30 | - name: Vet code 31 | run: go vet ./... 32 | - name: Run golangci-lint 33 | uses: golangci/golangci-lint-action@08e2f20817b15149a52b5b3ebe7de50aff2ba8c5 34 | 35 | go-test: 36 | runs-on: ubuntu-latest 37 | strategy: 38 | matrix: 39 | go-version: 40 | - '1.23' # named in go.mod 41 | - 'oldstable' 42 | - 'stable' 43 | env: 44 | TEST_RESULTS_PATH: '/tmp/test-results' 45 | steps: 46 | - name: Checkout Code 47 | uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 48 | - name: Setup Go 49 | uses: actions/setup-go@4d34df0c2316fe8122ab82dc22947d607c0c91f9 # v4.0.0 50 | with: 51 | go-version: ${{ matrix.go-version }} 52 | - name: Install gotestsum 53 | uses: autero1/action-gotestsum@7263b9d73912eec65f46337689e59fac865c425f # v2.0.0 54 | with: 55 | gotestsum_version: 1.9.0 56 | 57 | - name: Create test directory 58 | run: mkdir -p "$TEST_RESULTS_PATH" 59 | - name: Run go tests 60 | run: | 61 | gotestsum --format=short-verbose --junitfile "$TEST_RESULTS_PATH/gotestsum-report.xml" -- -p 2 -cover -coverprofile=coverage.out ./... 
62 | - name: Upload and save artifacts 63 | uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 64 | with: 65 | path: ${{ env.TEST_RESULTS_PATH }} 66 | name: tests-linux-${{matrix.go-version}} 67 | - name: Upload coverage report 68 | uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 69 | with: 70 | path: coverage.out 71 | name: Coverage-report-${{matrix.go-version}} 72 | - name: Display coverage report 73 | run: go tool cover -func=coverage.out 74 | -------------------------------------------------------------------------------- /.github/workflows/two-step-pr-approval.yml: -------------------------------------------------------------------------------- 1 | name: Two-Stage PR Review Process 2 | # We have instituted this process to ensure that any PRs raised by IP compliance team members follow a comprehensive internal team review first and then become open for wider review for shared/core libraries that we co-own. 3 | 4 | on: 5 | pull_request: 6 | types: [opened, synchronize, reopened, labeled, unlabeled, ready_for_review, converted_to_draft] 7 | pull_request_review: 8 | types: [submitted] 9 | 10 | jobs: 11 | manage-pr-status: 12 | runs-on: ubuntu-latest 13 | permissions: 14 | pull-requests: write 15 | contents: write 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac # v4.0.0 19 | 20 | - name: Two stage PR review 21 | uses: hashicorp/two-stage-pr-approval@v0.1.0 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go
21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | .idea 27 | -------------------------------------------------------------------------------- /.go-version: -------------------------------------------------------------------------------- 1 | 1.23 2 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | disable-all: true 3 | enable: 4 | - errcheck 5 | output_format: colored-line-number 6 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | # More on CODEOWNERS files: https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners 3 | 4 | # Default owner 5 | * @hashicorp/team-ip-compliance @hashicorp/raft-force 6 | 7 | # Add override rules below. Each line is a file/folder pattern followed by one or more owners. 8 | # Being an owner means those groups or individuals will be added as reviewers to PRs affecting 9 | # those areas of the code. 10 | # Examples: 11 | # /docs/ @docs-team 12 | # *.js @js-team 13 | # *.go @go-team 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 HashiCorp, Inc. 2 | 3 | Mozilla Public License, version 2.0 4 | 5 | 1. Definitions 6 | 7 | 1.1. "Contributor" 8 | 9 | means each individual or legal entity that creates, contributes to the 10 | creation of, or owns Covered Software. 11 | 12 | 1.2. "Contributor Version" 13 | 14 | means the combination of the Contributions of others (if any) used by a 15 | Contributor and that particular Contributor's Contribution. 16 | 17 | 1.3. 
"Contribution" 18 | 19 | means Covered Software of a particular Contributor. 20 | 21 | 1.4. "Covered Software" 22 | 23 | means Source Code Form to which the initial Contributor has attached the 24 | notice in Exhibit A, the Executable Form of such Source Code Form, and 25 | Modifications of such Source Code Form, in each case including portions 26 | thereof. 27 | 28 | 1.5. "Incompatible With Secondary Licenses" 29 | means 30 | 31 | a. that the initial Contributor has attached the notice described in 32 | Exhibit B to the Covered Software; or 33 | 34 | b. that the Covered Software was made available under the terms of 35 | version 1.1 or earlier of the License, but not also under the terms of 36 | a Secondary License. 37 | 38 | 1.6. "Executable Form" 39 | 40 | means any form of the work other than Source Code Form. 41 | 42 | 1.7. "Larger Work" 43 | 44 | means a work that combines Covered Software with other material, in a 45 | separate file or files, that is not Covered Software. 46 | 47 | 1.8. "License" 48 | 49 | means this document. 50 | 51 | 1.9. "Licensable" 52 | 53 | means having the right to grant, to the maximum extent possible, whether 54 | at the time of the initial grant or subsequently, any and all of the 55 | rights conveyed by this License. 56 | 57 | 1.10. "Modifications" 58 | 59 | means any of the following: 60 | 61 | a. any file in Source Code Form that results from an addition to, 62 | deletion from, or modification of the contents of Covered Software; or 63 | 64 | b. any new file in Source Code Form that contains any Covered Software. 65 | 66 | 1.11. "Patent Claims" of a Contributor 67 | 68 | means any patent claim(s), including without limitation, method, 69 | process, and apparatus claims, in any patent Licensable by such 70 | Contributor that would be infringed, but for the grant of the License, 71 | by the making, using, selling, offering for sale, having made, import, 72 | or transfer of either its Contributions or its Contributor Version. 
73 | 74 | 1.12. "Secondary License" 75 | 76 | means either the GNU General Public License, Version 2.0, the GNU Lesser 77 | General Public License, Version 2.1, the GNU Affero General Public 78 | License, Version 3.0, or any later versions of those licenses. 79 | 80 | 1.13. "Source Code Form" 81 | 82 | means the form of the work preferred for making modifications. 83 | 84 | 1.14. "You" (or "Your") 85 | 86 | means an individual or a legal entity exercising rights under this 87 | License. For legal entities, "You" includes any entity that controls, is 88 | controlled by, or is under common control with You. For purposes of this 89 | definition, "control" means (a) the power, direct or indirect, to cause 90 | the direction or management of such entity, whether by contract or 91 | otherwise, or (b) ownership of more than fifty percent (50%) of the 92 | outstanding shares or beneficial ownership of such entity. 93 | 94 | 95 | 2. License Grants and Conditions 96 | 97 | 2.1. Grants 98 | 99 | Each Contributor hereby grants You a world-wide, royalty-free, 100 | non-exclusive license: 101 | 102 | a. under intellectual property rights (other than patent or trademark) 103 | Licensable by such Contributor to use, reproduce, make available, 104 | modify, display, perform, distribute, and otherwise exploit its 105 | Contributions, either on an unmodified basis, with Modifications, or 106 | as part of a Larger Work; and 107 | 108 | b. under Patent Claims of such Contributor to make, use, sell, offer for 109 | sale, have made, import, and otherwise transfer either its 110 | Contributions or its Contributor Version. 111 | 112 | 2.2. Effective Date 113 | 114 | The licenses granted in Section 2.1 with respect to any Contribution 115 | become effective for each Contribution on the date the Contributor first 116 | distributes such Contribution. 117 | 118 | 2.3. 
Limitations on Grant Scope 119 | 120 | The licenses granted in this Section 2 are the only rights granted under 121 | this License. No additional rights or licenses will be implied from the 122 | distribution or licensing of Covered Software under this License. 123 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 124 | Contributor: 125 | 126 | a. for any code that a Contributor has removed from Covered Software; or 127 | 128 | b. for infringements caused by: (i) Your and any other third party's 129 | modifications of Covered Software, or (ii) the combination of its 130 | Contributions with other software (except as part of its Contributor 131 | Version); or 132 | 133 | c. under Patent Claims infringed by Covered Software in the absence of 134 | its Contributions. 135 | 136 | This License does not grant any rights in the trademarks, service marks, 137 | or logos of any Contributor (except as may be necessary to comply with 138 | the notice requirements in Section 3.4). 139 | 140 | 2.4. Subsequent Licenses 141 | 142 | No Contributor makes additional grants as a result of Your choice to 143 | distribute the Covered Software under a subsequent version of this 144 | License (see Section 10.2) or under the terms of a Secondary License (if 145 | permitted under the terms of Section 3.3). 146 | 147 | 2.5. Representation 148 | 149 | Each Contributor represents that the Contributor believes its 150 | Contributions are its original creation(s) or it has sufficient rights to 151 | grant the rights to its Contributions conveyed by this License. 152 | 153 | 2.6. Fair Use 154 | 155 | This License is not intended to limit any rights You have under 156 | applicable copyright doctrines of fair use, fair dealing, or other 157 | equivalents. 158 | 159 | 2.7. Conditions 160 | 161 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in 162 | Section 2.1. 163 | 164 | 165 | 3. Responsibilities 166 | 167 | 3.1. 
Distribution of Source Form 168 | 169 | All distribution of Covered Software in Source Code Form, including any 170 | Modifications that You create or to which You contribute, must be under 171 | the terms of this License. You must inform recipients that the Source 172 | Code Form of the Covered Software is governed by the terms of this 173 | License, and how they can obtain a copy of this License. You may not 174 | attempt to alter or restrict the recipients' rights in the Source Code 175 | Form. 176 | 177 | 3.2. Distribution of Executable Form 178 | 179 | If You distribute Covered Software in Executable Form then: 180 | 181 | a. such Covered Software must also be made available in Source Code Form, 182 | as described in Section 3.1, and You must inform recipients of the 183 | Executable Form how they can obtain a copy of such Source Code Form by 184 | reasonable means in a timely manner, at a charge no more than the cost 185 | of distribution to the recipient; and 186 | 187 | b. You may distribute such Executable Form under the terms of this 188 | License, or sublicense it under different terms, provided that the 189 | license for the Executable Form does not attempt to limit or alter the 190 | recipients' rights in the Source Code Form under this License. 191 | 192 | 3.3. Distribution of a Larger Work 193 | 194 | You may create and distribute a Larger Work under terms of Your choice, 195 | provided that You also comply with the requirements of this License for 196 | the Covered Software. 
If the Larger Work is a combination of Covered 197 | Software with a work governed by one or more Secondary Licenses, and the 198 | Covered Software is not Incompatible With Secondary Licenses, this 199 | License permits You to additionally distribute such Covered Software 200 | under the terms of such Secondary License(s), so that the recipient of 201 | the Larger Work may, at their option, further distribute the Covered 202 | Software under the terms of either this License or such Secondary 203 | License(s). 204 | 205 | 3.4. Notices 206 | 207 | You may not remove or alter the substance of any license notices 208 | (including copyright notices, patent notices, disclaimers of warranty, or 209 | limitations of liability) contained within the Source Code Form of the 210 | Covered Software, except that You may alter any license notices to the 211 | extent required to remedy known factual inaccuracies. 212 | 213 | 3.5. Application of Additional Terms 214 | 215 | You may choose to offer, and to charge a fee for, warranty, support, 216 | indemnity or liability obligations to one or more recipients of Covered 217 | Software. However, You may do so only on Your own behalf, and not on 218 | behalf of any Contributor. You must make it absolutely clear that any 219 | such warranty, support, indemnity, or liability obligation is offered by 220 | You alone, and You hereby agree to indemnify every Contributor for any 221 | liability incurred by such Contributor as a result of warranty, support, 222 | indemnity or liability terms You offer. You may include additional 223 | disclaimers of warranty and limitations of liability specific to any 224 | jurisdiction. 225 | 226 | 4. 
Inability to Comply Due to Statute or Regulation 227 | 228 | If it is impossible for You to comply with any of the terms of this License 229 | with respect to some or all of the Covered Software due to statute, 230 | judicial order, or regulation then You must: (a) comply with the terms of 231 | this License to the maximum extent possible; and (b) describe the 232 | limitations and the code they affect. Such description must be placed in a 233 | text file included with all distributions of the Covered Software under 234 | this License. Except to the extent prohibited by statute or regulation, 235 | such description must be sufficiently detailed for a recipient of ordinary 236 | skill to be able to understand it. 237 | 238 | 5. Termination 239 | 240 | 5.1. The rights granted under this License will terminate automatically if You 241 | fail to comply with any of its terms. However, if You become compliant, 242 | then the rights granted under this License from a particular Contributor 243 | are reinstated (a) provisionally, unless and until such Contributor 244 | explicitly and finally terminates Your grants, and (b) on an ongoing 245 | basis, if such Contributor fails to notify You of the non-compliance by 246 | some reasonable means prior to 60 days after You have come back into 247 | compliance. Moreover, Your grants from a particular Contributor are 248 | reinstated on an ongoing basis if such Contributor notifies You of the 249 | non-compliance by some reasonable means, this is the first time You have 250 | received notice of non-compliance with this License from such 251 | Contributor, and You become compliant prior to 30 days after Your receipt 252 | of the notice. 253 | 254 | 5.2. 
If You initiate litigation against any entity by asserting a patent 255 | infringement claim (excluding declaratory judgment actions, 256 | counter-claims, and cross-claims) alleging that a Contributor Version 257 | directly or indirectly infringes any patent, then the rights granted to 258 | You by any and all Contributors for the Covered Software under Section 259 | 2.1 of this License shall terminate. 260 | 261 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user 262 | license agreements (excluding distributors and resellers) which have been 263 | validly granted by You or Your distributors under this License prior to 264 | termination shall survive termination. 265 | 266 | 6. Disclaimer of Warranty 267 | 268 | Covered Software is provided under this License on an "as is" basis, 269 | without warranty of any kind, either expressed, implied, or statutory, 270 | including, without limitation, warranties that the Covered Software is free 271 | of defects, merchantable, fit for a particular purpose or non-infringing. 272 | The entire risk as to the quality and performance of the Covered Software 273 | is with You. Should any Covered Software prove defective in any respect, 274 | You (not any Contributor) assume the cost of any necessary servicing, 275 | repair, or correction. This disclaimer of warranty constitutes an essential 276 | part of this License. No use of any Covered Software is authorized under 277 | this License except under this disclaimer. 278 | 279 | 7. 
Limitation of Liability 280 | 281 | Under no circumstances and under no legal theory, whether tort (including 282 | negligence), contract, or otherwise, shall any Contributor, or anyone who 283 | distributes Covered Software as permitted above, be liable to You for any 284 | direct, indirect, special, incidental, or consequential damages of any 285 | character including, without limitation, damages for lost profits, loss of 286 | goodwill, work stoppage, computer failure or malfunction, or any and all 287 | other commercial damages or losses, even if such party shall have been 288 | informed of the possibility of such damages. This limitation of liability 289 | shall not apply to liability for death or personal injury resulting from 290 | such party's negligence to the extent applicable law prohibits such 291 | limitation. Some jurisdictions do not allow the exclusion or limitation of 292 | incidental or consequential damages, so this exclusion and limitation may 293 | not apply to You. 294 | 295 | 8. Litigation 296 | 297 | Any litigation relating to this License may be brought only in the courts 298 | of a jurisdiction where the defendant maintains its principal place of 299 | business and such litigation shall be governed by laws of that 300 | jurisdiction, without reference to its conflict-of-law provisions. Nothing 301 | in this Section shall prevent a party's ability to bring cross-claims or 302 | counter-claims. 303 | 304 | 9. Miscellaneous 305 | 306 | This License represents the complete agreement concerning the subject 307 | matter hereof. If any provision of this License is held to be 308 | unenforceable, such provision shall be reformed only to the extent 309 | necessary to make it enforceable. Any law or regulation which provides that 310 | the language of a contract shall be construed against the drafter shall not 311 | be used to construe this License against a Contributor. 312 | 313 | 314 | 10. Versions of the License 315 | 316 | 10.1. 
New Versions 317 | 318 | Mozilla Foundation is the license steward. Except as provided in Section 319 | 10.3, no one other than the license steward has the right to modify or 320 | publish new versions of this License. Each version will be given a 321 | distinguishing version number. 322 | 323 | 10.2. Effect of New Versions 324 | 325 | You may distribute the Covered Software under the terms of the version 326 | of the License under which You originally received the Covered Software, 327 | or under the terms of any subsequent version published by the license 328 | steward. 329 | 330 | 10.3. Modified Versions 331 | 332 | If you create software not governed by this License, and you want to 333 | create a new license for such software, you may create and use a 334 | modified version of this License if you rename the license and remove 335 | any references to the name of the license steward (except to note that 336 | such modified license differs from this License). 337 | 338 | 10.4. Distributing Source Code Form that is Incompatible With Secondary 339 | Licenses If You choose to distribute Source Code Form that is 340 | Incompatible With Secondary Licenses under the terms of this version of 341 | the License, the notice described in Exhibit B of this License must be 342 | attached. 343 | 344 | Exhibit A - Source Code Form License Notice 345 | 346 | This Source Code Form is subject to the 347 | terms of the Mozilla Public License, v. 348 | 2.0. If a copy of the MPL was not 349 | distributed with this file, You can 350 | obtain one at 351 | http://mozilla.org/MPL/2.0/. 352 | 353 | If it is not possible or desirable to put the notice in a particular file, 354 | then You may include the notice in a location (such as a LICENSE file in a 355 | relevant directory) where a recipient would be likely to look for such a 356 | notice. 357 | 358 | You may add additional accurate notices of copyright ownership. 
359 | 360 | Exhibit B - "Incompatible With Secondary Licenses" Notice 361 | 362 | This Source Code Form is "Incompatible 363 | With Secondary Licenses", as defined by 364 | the Mozilla Public License, v. 2.0. 365 | 366 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-memdb [![CircleCI](https://circleci.com/gh/hashicorp/go-memdb/tree/master.svg?style=svg)](https://circleci.com/gh/hashicorp/go-memdb/tree/master) 2 | 3 | Provides the `memdb` package that implements a simple in-memory database 4 | built on immutable radix trees. The database provides Atomicity, Consistency 5 | and Isolation from ACID. Being that it is in-memory, it does not provide durability. 6 | The database is instantiated with a schema that specifies the tables and indices 7 | that exist and allows transactions to be executed. 8 | 9 | The database provides the following: 10 | 11 | * Multi-Version Concurrency Control (MVCC) - By leveraging immutable radix trees 12 | the database is able to support any number of concurrent readers without locking, 13 | and allows a writer to make progress. 14 | 15 | * Transaction Support - The database allows for rich transactions, in which multiple 16 | objects are inserted, updated or deleted. The transactions can span multiple tables, 17 | and are applied atomically. The database provides atomicity and isolation in ACID 18 | terminology, such that until commit the updates are not visible. 19 | 20 | * Rich Indexing - Tables can support any number of indexes, which can be simple like 21 | a single field index, or more advanced compound field indexes. Certain types like 22 | UUID can be efficiently compressed from strings into byte indexes for reduced 23 | storage requirements. 
24 | 25 | * Watches - Callers can populate a watch set as part of a query, which can be used to 26 | detect when a modification has been made to the database which affects the query 27 | results. This lets callers easily watch for changes in the database in a very general 28 | way. 29 | 30 | For the underlying immutable radix trees, see [go-immutable-radix](https://github.com/hashicorp/go-immutable-radix). 31 | 32 | Documentation 33 | ============= 34 | 35 | The full documentation is available on [Godoc](https://pkg.go.dev/github.com/hashicorp/go-memdb). 36 | 37 | Example 38 | ======= 39 | 40 | Below is a [simple example](https://play.golang.org/p/gCGE9FA4og1) of usage 41 | 42 | ```go 43 | // Create a sample struct 44 | type Person struct { 45 | Email string 46 | Name string 47 | Age int 48 | } 49 | 50 | // Create the DB schema 51 | schema := &memdb.DBSchema{ 52 | Tables: map[string]*memdb.TableSchema{ 53 | "person": &memdb.TableSchema{ 54 | Name: "person", 55 | Indexes: map[string]*memdb.IndexSchema{ 56 | "id": &memdb.IndexSchema{ 57 | Name: "id", 58 | Unique: true, 59 | Indexer: &memdb.StringFieldIndex{Field: "Email"}, 60 | }, 61 | "age": &memdb.IndexSchema{ 62 | Name: "age", 63 | Unique: false, 64 | Indexer: &memdb.IntFieldIndex{Field: "Age"}, 65 | }, 66 | }, 67 | }, 68 | }, 69 | } 70 | 71 | // Create a new data base 72 | db, err := memdb.NewMemDB(schema) 73 | if err != nil { 74 | panic(err) 75 | } 76 | 77 | // Create a write transaction 78 | txn := db.Txn(true) 79 | 80 | // Insert some people 81 | people := []*Person{ 82 | &Person{"joe@aol.com", "Joe", 30}, 83 | &Person{"lucy@aol.com", "Lucy", 35}, 84 | &Person{"tariq@aol.com", "Tariq", 21}, 85 | &Person{"dorothy@aol.com", "Dorothy", 53}, 86 | } 87 | for _, p := range people { 88 | if err := txn.Insert("person", p); err != nil { 89 | panic(err) 90 | } 91 | } 92 | 93 | // Commit the transaction 94 | txn.Commit() 95 | 96 | // Create read-only transaction 97 | txn = db.Txn(false) 98 | defer txn.Abort() 99 | 100 | 
// Lookup by email 101 | raw, err := txn.First("person", "id", "joe@aol.com") 102 | if err != nil { 103 | panic(err) 104 | } 105 | 106 | // Say hi! 107 | fmt.Printf("Hello %s!\n", raw.(*Person).Name) 108 | 109 | // List all the people 110 | it, err := txn.Get("person", "id") 111 | if err != nil { 112 | panic(err) 113 | } 114 | 115 | fmt.Println("All the people:") 116 | for obj := it.Next(); obj != nil; obj = it.Next() { 117 | p := obj.(*Person) 118 | fmt.Printf(" %s\n", p.Name) 119 | } 120 | 121 | // Range scan over people with ages between 25 and 35 inclusive 122 | it, err = txn.LowerBound("person", "age", 25) 123 | if err != nil { 124 | panic(err) 125 | } 126 | 127 | fmt.Println("People aged 25 - 35:") 128 | for obj := it.Next(); obj != nil; obj = it.Next() { 129 | p := obj.(*Person) 130 | if p.Age > 35 { 131 | break 132 | } 133 | fmt.Printf(" %s is aged %d\n", p.Name, p.Age) 134 | } 135 | // Output: 136 | // Hello Joe! 137 | // All the people: 138 | // Dorothy 139 | // Joe 140 | // Lucy 141 | // Tariq 142 | // People aged 25 - 35: 143 | // Joe is aged 30 144 | // Lucy is aged 35 145 | ``` 146 | 147 | -------------------------------------------------------------------------------- /changes.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | // Changes describes a set of mutations to memDB tables performed during a 7 | // transaction. 8 | type Changes []Change 9 | 10 | // Change describes a mutation to an object in a table. 11 | type Change struct { 12 | Table string 13 | Before interface{} 14 | After interface{} 15 | 16 | // primaryKey stores the raw key value from the primary index so that we can 17 | // de-duplicate multiple updates of the same object in the same transaction 18 | // but we don't expose this implementation detail to the consumer. 
19 | primaryKey []byte 20 | } 21 | 22 | // Created returns true if the mutation describes a new object being inserted. 23 | func (m *Change) Created() bool { 24 | return m.Before == nil && m.After != nil 25 | } 26 | 27 | // Updated returns true if the mutation describes an existing object being 28 | // updated. 29 | func (m *Change) Updated() bool { 30 | return m.Before != nil && m.After != nil 31 | } 32 | 33 | // Deleted returns true if the mutation describes an existing object being 34 | // deleted. 35 | func (m *Change) Deleted() bool { 36 | return m.Before != nil && m.After == nil 37 | } 38 | -------------------------------------------------------------------------------- /filter.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | // FilterFunc is a function that takes the results of an iterator and returns 7 | // whether the result should be filtered out. 8 | type FilterFunc func(interface{}) bool 9 | 10 | // FilterIterator is used to wrap a ResultIterator and apply a filter over it. 11 | type FilterIterator struct { 12 | // filter is the filter function applied over the base iterator. 13 | filter FilterFunc 14 | 15 | // iter is the iterator that is being wrapped. 16 | iter ResultIterator 17 | } 18 | 19 | // NewFilterIterator wraps a ResultIterator. The filter function is applied 20 | // to each value returned by a call to iter.Next. 21 | // 22 | // See the documentation for ResultIterator to understand the behaviour of the 23 | // returned FilterIterator. 24 | func NewFilterIterator(iter ResultIterator, filter FilterFunc) *FilterIterator { 25 | return &FilterIterator{ 26 | filter: filter, 27 | iter: iter, 28 | } 29 | } 30 | 31 | // WatchCh returns the watch channel of the wrapped iterator. 
32 | func (f *FilterIterator) WatchCh() <-chan struct{} { return f.iter.WatchCh() } 33 | 34 | // Next returns the next non-filtered result from the wrapped iterator. 35 | func (f *FilterIterator) Next() interface{} { 36 | for { 37 | if value := f.iter.Next(); value == nil || !f.filter(value) { 38 | return value 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /filter_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import "testing" 7 | 8 | // Test that the iterator meets the required interface 9 | func TestFilterIterator_Interface(t *testing.T) { 10 | var _ ResultIterator = &FilterIterator{} 11 | } 12 | 13 | func TestFilterIterator(t *testing.T) { 14 | db := testDB(t) 15 | txn := db.Txn(true) 16 | 17 | obj1 := &TestObject{ 18 | ID: "a", 19 | Foo: "xyz", 20 | Qux: []string{"xyz1"}, 21 | } 22 | obj2 := &TestObject{ 23 | ID: "medium-length", 24 | Foo: "xyz", 25 | Qux: []string{"xyz1", "xyz2"}, 26 | } 27 | obj3 := &TestObject{ 28 | ID: "super-long-unique-identifier", 29 | Foo: "xyz", 30 | Qux: []string{"xyz1", "xyz2"}, 31 | } 32 | 33 | err := txn.Insert("main", obj1) 34 | if err != nil { 35 | t.Fatalf("err: %v", err) 36 | } 37 | err = txn.Insert("main", obj2) 38 | if err != nil { 39 | t.Fatalf("err: %v", err) 40 | } 41 | err = txn.Insert("main", obj3) 42 | if err != nil { 43 | t.Fatalf("err: %v", err) 44 | } 45 | 46 | filterFactory := func(idLengthLimit int) func(interface{}) bool { 47 | limit := idLengthLimit 48 | return func(raw interface{}) bool { 49 | obj, ok := raw.(*TestObject) 50 | if !ok { 51 | return true 52 | } 53 | 54 | return len(obj.ID) > limit 55 | } 56 | } 57 | 58 | checkResult := func(txn *Txn) { 59 | // Attempt a row scan on the ID 60 | result, err := txn.Get("main", "id") 61 | if err != nil { 62 | t.Fatalf("err: %v", err) 63 | } 64 | 65 | // Wrap 
the iterator and try various filters 66 | filter := NewFilterIterator(result, filterFactory(6)) 67 | if raw := filter.Next(); raw != obj1 { 68 | t.Fatalf("bad: %#v %#v", raw, obj1) 69 | } 70 | 71 | if raw := filter.Next(); raw != nil { 72 | t.Fatalf("bad: %#v %#v", raw, nil) 73 | } 74 | 75 | result, err = txn.Get("main", "id") 76 | if err != nil { 77 | t.Fatalf("err: %v", err) 78 | } 79 | 80 | filter = NewFilterIterator(result, filterFactory(15)) 81 | if raw := filter.Next(); raw != obj1 { 82 | t.Fatalf("bad: %#v %#v", raw, obj1) 83 | } 84 | 85 | if raw := filter.Next(); raw != obj2 { 86 | t.Fatalf("bad: %#v %#v", raw, obj2) 87 | } 88 | 89 | if raw := filter.Next(); raw != nil { 90 | t.Fatalf("bad: %#v %#v", raw, nil) 91 | } 92 | 93 | result, err = txn.Get("main", "id") 94 | if err != nil { 95 | t.Fatalf("err: %v", err) 96 | } 97 | 98 | filter = NewFilterIterator(result, filterFactory(150)) 99 | if raw := filter.Next(); raw != obj1 { 100 | t.Fatalf("bad: %#v %#v", raw, obj1) 101 | } 102 | 103 | if raw := filter.Next(); raw != obj2 { 104 | t.Fatalf("bad: %#v %#v", raw, obj2) 105 | } 106 | 107 | if raw := filter.Next(); raw != obj3 { 108 | t.Fatalf("bad: %#v %#v", raw, obj3) 109 | } 110 | 111 | if raw := filter.Next(); raw != nil { 112 | t.Fatalf("bad: %#v %#v", raw, nil) 113 | } 114 | } 115 | 116 | // Check the results within the txn 117 | checkResult(txn) 118 | 119 | // Commit and start a new read transaction 120 | txn.Commit() 121 | txn = db.Txn(false) 122 | 123 | // Check the results in a new txn 124 | checkResult(txn) 125 | } 126 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-memdb 2 | 3 | go 1.23 4 | 5 | require github.com/hashicorp/go-immutable-radix v1.3.1 6 | 7 | require github.com/hashicorp/golang-lru v0.5.4 // indirect 8 | 
-------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJFeZnpfm2KLowc= 2 | github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= 3 | github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= 4 | github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= 5 | github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= 6 | github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= 7 | github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= 8 | -------------------------------------------------------------------------------- /index.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import ( 7 | "encoding/binary" 8 | "encoding/hex" 9 | "errors" 10 | "fmt" 11 | "reflect" 12 | "strconv" 13 | "strings" 14 | ) 15 | 16 | // Indexer is an interface used for defining indexes. Indexes are used 17 | // for efficient lookup of objects in a MemDB table. An Indexer must also 18 | // implement one of SingleIndexer or MultiIndexer. 19 | // 20 | // Indexers are primarily responsible for returning the lookup key as 21 | // a byte slice. The byte slice is the key data in the underlying data storage. 22 | type Indexer interface { 23 | // FromArgs is called to build the exact index key from a list of arguments. 
24 | FromArgs(args ...interface{}) ([]byte, error) 25 | } 26 | 27 | // SingleIndexer is an interface used for defining indexes that generate a 28 | // single value per object. 29 | type SingleIndexer interface { 30 | // FromObject extracts the index value from an object. The return values 31 | // are whether the index value was found, the index value, and any error 32 | // while extracting the index value, respectively. 33 | FromObject(raw interface{}) (bool, []byte, error) 34 | } 35 | 36 | // MultiIndexer is an interface used for defining indexes that generate 37 | // multiple values per object. Each value is stored as a separate index 38 | // pointing to the same object. 39 | // 40 | // For example, an index that extracts the first and last name of a person 41 | // and allows lookup based on either would be a MultiIndexer. The FromObject 42 | // of this example would split the first and last name and return both as 43 | // values. 44 | type MultiIndexer interface { 45 | // FromObject extracts index values from an object. The return values 46 | // are the same as a SingleIndexer except there can be multiple index 47 | // values. 48 | FromObject(raw interface{}) (bool, [][]byte, error) 49 | } 50 | 51 | // PrefixIndexer is an optional interface on top of an Indexer that allows 52 | // indexes to support prefix-based iteration. 53 | type PrefixIndexer interface { 54 | // PrefixFromArgs is the same as FromArgs for an Indexer except that 55 | // the index value returned should return all prefix-matched values. 56 | PrefixFromArgs(args ...interface{}) ([]byte, error) 57 | } 58 | 59 | // StringFieldIndex is used to extract a field from an object 60 | // using reflection and builds an index on that field. 
61 | type StringFieldIndex struct { 62 | Field string 63 | Lowercase bool 64 | } 65 | 66 | func (s *StringFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { 67 | v := reflect.ValueOf(obj) 68 | v = reflect.Indirect(v) // Dereference the pointer if any 69 | 70 | fv := v.FieldByName(s.Field) 71 | isPtr := fv.Kind() == reflect.Ptr 72 | fv = reflect.Indirect(fv) 73 | if !isPtr && !fv.IsValid() { 74 | return false, nil, 75 | fmt.Errorf("field '%s' for %#v is invalid %v ", s.Field, obj, isPtr) 76 | } 77 | 78 | if isPtr && !fv.IsValid() { 79 | val := "" 80 | return false, []byte(val), nil 81 | } 82 | 83 | val := fv.String() 84 | if val == "" { 85 | return false, nil, nil 86 | } 87 | 88 | if s.Lowercase { 89 | val = strings.ToLower(val) 90 | } 91 | 92 | // Add the null character as a terminator 93 | val += "\x00" 94 | return true, []byte(val), nil 95 | } 96 | 97 | func (s *StringFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { 98 | if len(args) != 1 { 99 | return nil, fmt.Errorf("must provide only a single argument") 100 | } 101 | arg, ok := args[0].(string) 102 | if !ok { 103 | return nil, fmt.Errorf("argument must be a string: %#v", args[0]) 104 | } 105 | if s.Lowercase { 106 | arg = strings.ToLower(arg) 107 | } 108 | // Add the null character as a terminator 109 | arg += "\x00" 110 | return []byte(arg), nil 111 | } 112 | 113 | func (s *StringFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { 114 | val, err := s.FromArgs(args...) 115 | if err != nil { 116 | return nil, err 117 | } 118 | 119 | // Strip the null terminator, the rest is a prefix 120 | n := len(val) 121 | if n > 0 { 122 | return val[:n-1], nil 123 | } 124 | return val, nil 125 | } 126 | 127 | // StringSliceFieldIndex builds an index from a field on an object that is a 128 | // string slice ([]string). Each value within the string slice can be used for 129 | // lookup. 
130 | type StringSliceFieldIndex struct { 131 | Field string 132 | Lowercase bool 133 | } 134 | 135 | func (s *StringSliceFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) { 136 | v := reflect.ValueOf(obj) 137 | v = reflect.Indirect(v) // Dereference the pointer if any 138 | 139 | fv := v.FieldByName(s.Field) 140 | if !fv.IsValid() { 141 | return false, nil, 142 | fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj) 143 | } 144 | 145 | if fv.Kind() != reflect.Slice || fv.Type().Elem().Kind() != reflect.String { 146 | return false, nil, fmt.Errorf("field '%s' is not a string slice", s.Field) 147 | } 148 | 149 | length := fv.Len() 150 | vals := make([][]byte, 0, length) 151 | for i := 0; i < fv.Len(); i++ { 152 | val := fv.Index(i).String() 153 | if val == "" { 154 | continue 155 | } 156 | 157 | if s.Lowercase { 158 | val = strings.ToLower(val) 159 | } 160 | 161 | // Add the null character as a terminator 162 | val += "\x00" 163 | vals = append(vals, []byte(val)) 164 | } 165 | if len(vals) == 0 { 166 | return false, nil, nil 167 | } 168 | return true, vals, nil 169 | } 170 | 171 | func (s *StringSliceFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { 172 | if len(args) != 1 { 173 | return nil, fmt.Errorf("must provide only a single argument") 174 | } 175 | arg, ok := args[0].(string) 176 | if !ok { 177 | return nil, fmt.Errorf("argument must be a string: %#v", args[0]) 178 | } 179 | if s.Lowercase { 180 | arg = strings.ToLower(arg) 181 | } 182 | // Add the null character as a terminator 183 | arg += "\x00" 184 | return []byte(arg), nil 185 | } 186 | 187 | func (s *StringSliceFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { 188 | val, err := s.FromArgs(args...) 
189 | if err != nil { 190 | return nil, err 191 | } 192 | 193 | // Strip the null terminator, the rest is a prefix 194 | n := len(val) 195 | if n > 0 { 196 | return val[:n-1], nil 197 | } 198 | return val, nil 199 | } 200 | 201 | // StringMapFieldIndex is used to extract a field of type map[string]string 202 | // from an object using reflection and builds an index on that field. 203 | // 204 | // Note that although FromArgs in theory supports using either one or 205 | // two arguments, there is a bug: FromObject only creates an index 206 | // using key/value, and does not also create an index using key. This 207 | // means a lookup using one argument will never actually work. 208 | // 209 | // It is currently left as-is to prevent backwards compatibility 210 | // issues. 211 | // 212 | // TODO: Fix this in the next major bump. 213 | type StringMapFieldIndex struct { 214 | Field string 215 | Lowercase bool 216 | } 217 | 218 | var MapType = reflect.MapOf(reflect.TypeOf(""), reflect.TypeOf("")).Kind() 219 | 220 | func (s *StringMapFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) { 221 | v := reflect.ValueOf(obj) 222 | v = reflect.Indirect(v) // Dereference the pointer if any 223 | 224 | fv := v.FieldByName(s.Field) 225 | if !fv.IsValid() { 226 | return false, nil, fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj) 227 | } 228 | 229 | if fv.Kind() != MapType { 230 | return false, nil, fmt.Errorf("field '%s' is not a map[string]string", s.Field) 231 | } 232 | 233 | length := fv.Len() 234 | vals := make([][]byte, 0, length) 235 | for _, key := range fv.MapKeys() { 236 | k := key.String() 237 | if k == "" { 238 | continue 239 | } 240 | val := fv.MapIndex(key).String() 241 | 242 | if s.Lowercase { 243 | k = strings.ToLower(k) 244 | val = strings.ToLower(val) 245 | } 246 | 247 | // Add the null character as a terminator 248 | k += "\x00" + val + "\x00" 249 | 250 | vals = append(vals, []byte(k)) 251 | } 252 | if len(vals) == 0 { 253 | return false, nil, 
nil 254 | } 255 | return true, vals, nil 256 | } 257 | 258 | // WARNING: Because of a bug in FromObject, this function will never return 259 | // a value when using the single-argument version. 260 | func (s *StringMapFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { 261 | if len(args) > 2 || len(args) == 0 { 262 | return nil, fmt.Errorf("must provide one or two arguments") 263 | } 264 | key, ok := args[0].(string) 265 | if !ok { 266 | return nil, fmt.Errorf("argument must be a string: %#v", args[0]) 267 | } 268 | if s.Lowercase { 269 | key = strings.ToLower(key) 270 | } 271 | // Add the null character as a terminator 272 | key += "\x00" 273 | 274 | if len(args) == 2 { 275 | val, ok := args[1].(string) 276 | if !ok { 277 | return nil, fmt.Errorf("argument must be a string: %#v", args[1]) 278 | } 279 | if s.Lowercase { 280 | val = strings.ToLower(val) 281 | } 282 | // Add the null character as a terminator 283 | key += val + "\x00" 284 | } 285 | 286 | return []byte(key), nil 287 | } 288 | 289 | // IntFieldIndex is used to extract an int field from an object using 290 | // reflection and builds an index on that field. 
291 | type IntFieldIndex struct { 292 | Field string 293 | } 294 | 295 | func (i *IntFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { 296 | v := reflect.ValueOf(obj) 297 | v = reflect.Indirect(v) // Dereference the pointer if any 298 | 299 | fv := v.FieldByName(i.Field) 300 | if !fv.IsValid() { 301 | return false, nil, 302 | fmt.Errorf("field '%s' for %#v is invalid", i.Field, obj) 303 | } 304 | 305 | // Check the type 306 | k := fv.Kind() 307 | size, ok := IsIntType(k) 308 | if !ok { 309 | return false, nil, fmt.Errorf("field %q is of type %v; want an int", i.Field, k) 310 | } 311 | 312 | // Get the value and encode it 313 | val := fv.Int() 314 | buf := encodeInt(val, size) 315 | 316 | return true, buf, nil 317 | } 318 | 319 | func (i *IntFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { 320 | if len(args) != 1 { 321 | return nil, fmt.Errorf("must provide only a single argument") 322 | } 323 | 324 | v := reflect.ValueOf(args[0]) 325 | if !v.IsValid() { 326 | return nil, fmt.Errorf("%#v is invalid", args[0]) 327 | } 328 | 329 | k := v.Kind() 330 | size, ok := IsIntType(k) 331 | if !ok { 332 | return nil, fmt.Errorf("arg is of type %v; want a int", k) 333 | } 334 | 335 | val := v.Int() 336 | buf := encodeInt(val, size) 337 | 338 | return buf, nil 339 | } 340 | 341 | func encodeInt(val int64, size int) []byte { 342 | buf := make([]byte, size) 343 | 344 | // This bit flips the sign bit on any sized signed twos-complement integer, 345 | // which when truncated to a uint of the same size will bias the value such 346 | // that the maximum negative int becomes 0, and the maximum positive int 347 | // becomes the maximum positive uint. 
348 | scaled := val ^ int64(-1<<(size*8-1)) 349 | 350 | switch size { 351 | case 1: 352 | buf[0] = uint8(scaled) 353 | case 2: 354 | binary.BigEndian.PutUint16(buf, uint16(scaled)) 355 | case 4: 356 | binary.BigEndian.PutUint32(buf, uint32(scaled)) 357 | case 8: 358 | binary.BigEndian.PutUint64(buf, uint64(scaled)) 359 | default: 360 | panic(fmt.Sprintf("unsupported int size parameter: %d", size)) 361 | } 362 | 363 | return buf 364 | } 365 | 366 | // IsIntType returns whether the passed type is a type of int and the number 367 | // of bytes needed to encode the type. 368 | func IsIntType(k reflect.Kind) (size int, okay bool) { 369 | switch k { 370 | case reflect.Int: 371 | return strconv.IntSize / 8, true 372 | case reflect.Int8: 373 | return 1, true 374 | case reflect.Int16: 375 | return 2, true 376 | case reflect.Int32: 377 | return 4, true 378 | case reflect.Int64: 379 | return 8, true 380 | default: 381 | return 0, false 382 | } 383 | } 384 | 385 | // UintFieldIndex is used to extract a uint field from an object using 386 | // reflection and builds an index on that field. 
387 | type UintFieldIndex struct { 388 | Field string 389 | } 390 | 391 | func (u *UintFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { 392 | v := reflect.ValueOf(obj) 393 | v = reflect.Indirect(v) // Dereference the pointer if any 394 | 395 | fv := v.FieldByName(u.Field) 396 | if !fv.IsValid() { 397 | return false, nil, 398 | fmt.Errorf("field '%s' for %#v is invalid", u.Field, obj) 399 | } 400 | 401 | // Check the type 402 | k := fv.Kind() 403 | size, ok := IsUintType(k) 404 | if !ok { 405 | return false, nil, fmt.Errorf("field %q is of type %v; want a uint", u.Field, k) 406 | } 407 | 408 | // Get the value and encode it 409 | val := fv.Uint() 410 | buf := encodeUInt(val, size) 411 | 412 | return true, buf, nil 413 | } 414 | 415 | func (u *UintFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { 416 | if len(args) != 1 { 417 | return nil, fmt.Errorf("must provide only a single argument") 418 | } 419 | 420 | v := reflect.ValueOf(args[0]) 421 | if !v.IsValid() { 422 | return nil, fmt.Errorf("%#v is invalid", args[0]) 423 | } 424 | 425 | k := v.Kind() 426 | size, ok := IsUintType(k) 427 | if !ok { 428 | return nil, fmt.Errorf("arg is of type %v; want a uint", k) 429 | } 430 | 431 | val := v.Uint() 432 | buf := encodeUInt(val, size) 433 | 434 | return buf, nil 435 | } 436 | 437 | func encodeUInt(val uint64, size int) []byte { 438 | buf := make([]byte, size) 439 | 440 | switch size { 441 | case 1: 442 | buf[0] = uint8(val) 443 | case 2: 444 | binary.BigEndian.PutUint16(buf, uint16(val)) 445 | case 4: 446 | binary.BigEndian.PutUint32(buf, uint32(val)) 447 | case 8: 448 | binary.BigEndian.PutUint64(buf, val) 449 | default: 450 | panic(fmt.Sprintf("unsupported uint size parameter: %d", size)) 451 | } 452 | 453 | return buf 454 | } 455 | 456 | // IsUintType returns whether the passed type is a type of uint and the number 457 | // of bytes needed to encode the type. 
458 | func IsUintType(k reflect.Kind) (size int, okay bool) { 459 | switch k { 460 | case reflect.Uint: 461 | return strconv.IntSize / 8, true 462 | case reflect.Uint8: 463 | return 1, true 464 | case reflect.Uint16: 465 | return 2, true 466 | case reflect.Uint32: 467 | return 4, true 468 | case reflect.Uint64: 469 | return 8, true 470 | default: 471 | return 0, false 472 | } 473 | } 474 | 475 | // BoolFieldIndex is used to extract a boolean field from an object using 476 | // reflection and builds an index on that field. 477 | type BoolFieldIndex struct { 478 | Field string 479 | } 480 | 481 | func (i *BoolFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { 482 | v := reflect.ValueOf(obj) 483 | v = reflect.Indirect(v) // Dereference the pointer if any 484 | 485 | fv := v.FieldByName(i.Field) 486 | if !fv.IsValid() { 487 | return false, nil, 488 | fmt.Errorf("field '%s' for %#v is invalid", i.Field, obj) 489 | } 490 | 491 | // Check the type 492 | k := fv.Kind() 493 | if k != reflect.Bool { 494 | return false, nil, fmt.Errorf("field %q is of type %v; want a bool", i.Field, k) 495 | } 496 | 497 | // Get the value and encode it 498 | buf := make([]byte, 1) 499 | if fv.Bool() { 500 | buf[0] = 1 501 | } 502 | 503 | return true, buf, nil 504 | } 505 | 506 | func (i *BoolFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { 507 | return fromBoolArgs(args) 508 | } 509 | 510 | // UUIDFieldIndex is used to extract a field from an object 511 | // using reflection and builds an index on that field by treating 512 | // it as a UUID. This is an optimization to using a StringFieldIndex 513 | // as the UUID can be more compactly represented in byte form. 
514 | type UUIDFieldIndex struct { 515 | Field string 516 | } 517 | 518 | func (u *UUIDFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { 519 | v := reflect.ValueOf(obj) 520 | v = reflect.Indirect(v) // Dereference the pointer if any 521 | 522 | fv := v.FieldByName(u.Field) 523 | if !fv.IsValid() { 524 | return false, nil, 525 | fmt.Errorf("field '%s' for %#v is invalid", u.Field, obj) 526 | } 527 | 528 | val := fv.String() 529 | if val == "" { 530 | return false, nil, nil 531 | } 532 | 533 | buf, err := u.parseString(val, true) 534 | return true, buf, err 535 | } 536 | 537 | func (u *UUIDFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { 538 | if len(args) != 1 { 539 | return nil, fmt.Errorf("must provide only a single argument") 540 | } 541 | switch arg := args[0].(type) { 542 | case string: 543 | return u.parseString(arg, true) 544 | case []byte: 545 | if len(arg) != 16 { 546 | return nil, fmt.Errorf("byte slice must be 16 characters") 547 | } 548 | return arg, nil 549 | default: 550 | return nil, 551 | fmt.Errorf("argument must be a string or byte slice: %#v", args[0]) 552 | } 553 | } 554 | 555 | func (u *UUIDFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { 556 | if len(args) != 1 { 557 | return nil, fmt.Errorf("must provide only a single argument") 558 | } 559 | switch arg := args[0].(type) { 560 | case string: 561 | return u.parseString(arg, false) 562 | case []byte: 563 | return arg, nil 564 | default: 565 | return nil, 566 | fmt.Errorf("argument must be a string or byte slice: %#v", args[0]) 567 | } 568 | } 569 | 570 | // parseString parses a UUID from the string. If enforceLength is false, it will 571 | // parse a partial UUID. An error is returned if the input, stripped of hyphens, 572 | // is not even length. 
573 | func (u *UUIDFieldIndex) parseString(s string, enforceLength bool) ([]byte, error) { 574 | // Verify the length 575 | l := len(s) 576 | if enforceLength && l != 36 { 577 | return nil, fmt.Errorf("UUID must be 36 characters") 578 | } else if l > 36 { 579 | return nil, fmt.Errorf("Invalid UUID length. UUID have 36 characters; got %d", l) 580 | } 581 | 582 | hyphens := strings.Count(s, "-") 583 | if hyphens > 4 { 584 | return nil, fmt.Errorf(`UUID should have maximum of 4 "-"; got %d`, hyphens) 585 | } 586 | 587 | // The sanitized length is the length of the original string without the "-". 588 | sanitized := strings.Replace(s, "-", "", -1) 589 | sanitizedLength := len(sanitized) 590 | if sanitizedLength%2 != 0 { 591 | return nil, fmt.Errorf("Input (without hyphens) must be even length") 592 | } 593 | 594 | dec, err := hex.DecodeString(sanitized) 595 | if err != nil { 596 | return nil, fmt.Errorf("Invalid UUID: %v", err) 597 | } 598 | 599 | return dec, nil 600 | } 601 | 602 | // FieldSetIndex is used to extract a field from an object using reflection and 603 | // builds an index on whether the field is set by comparing it against its 604 | // type's nil value. 605 | type FieldSetIndex struct { 606 | Field string 607 | } 608 | 609 | func (f *FieldSetIndex) FromObject(obj interface{}) (bool, []byte, error) { 610 | v := reflect.ValueOf(obj) 611 | v = reflect.Indirect(v) // Dereference the pointer if any 612 | 613 | fv := v.FieldByName(f.Field) 614 | if !fv.IsValid() { 615 | return false, nil, 616 | fmt.Errorf("field '%s' for %#v is invalid", f.Field, obj) 617 | } 618 | 619 | if fv.Interface() == reflect.Zero(fv.Type()).Interface() { 620 | return true, []byte{0}, nil 621 | } 622 | 623 | return true, []byte{1}, nil 624 | } 625 | 626 | func (f *FieldSetIndex) FromArgs(args ...interface{}) ([]byte, error) { 627 | return fromBoolArgs(args) 628 | } 629 | 630 | // ConditionalIndex builds an index based on a condition specified by a passed 631 | // user function. 
This function may examine the passed object and return a 632 | // boolean to encapsulate an arbitrarily complex conditional. 633 | type ConditionalIndex struct { 634 | Conditional ConditionalIndexFunc 635 | } 636 | 637 | // ConditionalIndexFunc is the required function interface for a 638 | // ConditionalIndex. 639 | type ConditionalIndexFunc func(obj interface{}) (bool, error) 640 | 641 | func (c *ConditionalIndex) FromObject(obj interface{}) (bool, []byte, error) { 642 | // Call the user's function 643 | res, err := c.Conditional(obj) 644 | if err != nil { 645 | return false, nil, fmt.Errorf("ConditionalIndexFunc(%#v) failed: %v", obj, err) 646 | } 647 | 648 | if res { 649 | return true, []byte{1}, nil 650 | } 651 | 652 | return true, []byte{0}, nil 653 | } 654 | 655 | func (c *ConditionalIndex) FromArgs(args ...interface{}) ([]byte, error) { 656 | return fromBoolArgs(args) 657 | } 658 | 659 | // fromBoolArgs is a helper that expects only a single boolean argument and 660 | // returns a single length byte array containing either a one or zero depending 661 | // on whether the passed input is true or false respectively. 662 | func fromBoolArgs(args []interface{}) ([]byte, error) { 663 | if len(args) != 1 { 664 | return nil, fmt.Errorf("must provide only a single argument") 665 | } 666 | 667 | if val, ok := args[0].(bool); !ok { 668 | return nil, fmt.Errorf("argument must be a boolean type: %#v", args[0]) 669 | } else if val { 670 | return []byte{1}, nil 671 | } 672 | 673 | return []byte{0}, nil 674 | } 675 | 676 | // CompoundIndex is used to build an index using multiple sub-indexes 677 | // Prefix based iteration is supported as long as the appropriate prefix 678 | // of indexers support it. All sub-indexers are only assumed to expect 679 | // a single argument. 680 | type CompoundIndex struct { 681 | Indexes []Indexer 682 | 683 | // AllowMissing results in an index based on only the indexers 684 | // that return data. 
If true, you may end up with 2/3 columns 685 | // indexed which might be useful for an index scan. Otherwise, 686 | // the CompoundIndex requires all indexers to be satisfied. 687 | AllowMissing bool 688 | } 689 | 690 | func (c *CompoundIndex) FromObject(raw interface{}) (bool, []byte, error) { 691 | var out []byte 692 | for i, idxRaw := range c.Indexes { 693 | idx, ok := idxRaw.(SingleIndexer) 694 | if !ok { 695 | return false, nil, fmt.Errorf("sub-index %d error: %s", i, "sub-index must be a SingleIndexer") 696 | } 697 | ok, val, err := idx.FromObject(raw) 698 | if err != nil { 699 | return false, nil, fmt.Errorf("sub-index %d error: %v", i, err) 700 | } 701 | if !ok { 702 | if c.AllowMissing { 703 | break 704 | } else { 705 | return false, nil, nil 706 | } 707 | } 708 | out = append(out, val...) 709 | } 710 | return true, out, nil 711 | } 712 | 713 | func (c *CompoundIndex) FromArgs(args ...interface{}) ([]byte, error) { 714 | if len(args) != len(c.Indexes) { 715 | return nil, fmt.Errorf("non-equivalent argument count and index fields") 716 | } 717 | var out []byte 718 | for i, arg := range args { 719 | val, err := c.Indexes[i].FromArgs(arg) 720 | if err != nil { 721 | return nil, fmt.Errorf("sub-index %d error: %v", i, err) 722 | } 723 | out = append(out, val...) 724 | } 725 | return out, nil 726 | } 727 | 728 | func (c *CompoundIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { 729 | if len(args) > len(c.Indexes) { 730 | return nil, fmt.Errorf("more arguments than index fields") 731 | } 732 | var out []byte 733 | for i, arg := range args { 734 | if i+1 < len(args) { 735 | val, err := c.Indexes[i].FromArgs(arg) 736 | if err != nil { 737 | return nil, fmt.Errorf("sub-index %d error: %v", i, err) 738 | } 739 | out = append(out, val...) 
740 | } else { 741 | prefixIndexer, ok := c.Indexes[i].(PrefixIndexer) 742 | if !ok { 743 | return nil, fmt.Errorf("sub-index %d does not support prefix scanning", i) 744 | } 745 | val, err := prefixIndexer.PrefixFromArgs(arg) 746 | if err != nil { 747 | return nil, fmt.Errorf("sub-index %d error: %v", i, err) 748 | } 749 | out = append(out, val...) 750 | } 751 | } 752 | return out, nil 753 | } 754 | 755 | // CompoundMultiIndex is used to build an index using multiple 756 | // sub-indexes. 757 | // 758 | // Unlike CompoundIndex, CompoundMultiIndex can have both 759 | // SingleIndexer and MultiIndexer sub-indexers. However, each 760 | // MultiIndexer adds considerable overhead/complexity in terms of 761 | // the number of indexes created under-the-hood. It is not suggested 762 | // to use more than one or two, if possible. 763 | // 764 | // Another change from CompoundIndexer is that if AllowMissing is 765 | // set, not only is it valid to have empty index fields, but it will 766 | // still create index values up to the first empty index. This means 767 | // that if you have a value with an empty field, rather than using a 768 | // prefix for lookup, you can simply pass in less arguments. As an 769 | // example, if {Foo, Bar} is indexed but Bar is missing for a value 770 | // and AllowMissing is set, an index will still be created for {Foo} 771 | // and it is valid to do a lookup passing in only Foo as an argument. 772 | // Note that the ordering isn't guaranteed -- it's last-insert wins, 773 | // but this is true if you have two objects that have the same 774 | // indexes not using AllowMissing anyways. 775 | // 776 | // Because StringMapFieldIndexers can take a varying number of args, 777 | // it is currently a requirement that whenever it is used, two 778 | // arguments must _always_ be provided for it. In theory we only 779 | // need one, except a bug in that indexer means the single-argument 780 | // version will never work. 
You can leave the second argument nil,
// but it will never produce a value. We support this for whenever
// that bug is fixed, likely in a next major version bump.
//
// Prefix-based indexing is not currently supported.
type CompoundMultiIndex struct {
	Indexes []Indexer

	// AllowMissing results in an index based on only the indexers
	// that return data. If true, you may end up with 2/3 columns
	// indexed which might be useful for an index scan. Otherwise,
	// CompoundMultiIndex requires all indexers to be satisfied.
	AllowMissing bool
}

// FromObject builds every index key for raw by taking the cross
// product of the value sets produced by each sub-indexer. Each
// sub-indexer must be a SingleIndexer or a MultiIndexer. If a
// sub-indexer produces no values and AllowMissing is set, keys are
// still emitted for the prefix built so far; otherwise the object is
// not indexed.
func (c *CompoundMultiIndex) FromObject(raw interface{}) (bool, [][]byte, error) {
	// At each entry, builder is storing the results from the next index
	builder := make([][][]byte, 0, len(c.Indexes))

forloop:
	// This loop goes through each indexer and adds the value(s) provided to the next
	// entry in the slice. We can then later walk it like a tree to construct the indices.
	for i, idxRaw := range c.Indexes {
		switch idx := idxRaw.(type) {
		case SingleIndexer:
			ok, val, err := idx.FromObject(raw)
			if err != nil {
				return false, nil, fmt.Errorf("single sub-index %d error: %v", i, err)
			}
			if !ok {
				if c.AllowMissing {
					break forloop
				}
				return false, nil, nil
			}
			builder = append(builder, [][]byte{val})

		case MultiIndexer:
			ok, vals, err := idx.FromObject(raw)
			if err != nil {
				return false, nil, fmt.Errorf("multi sub-index %d error: %v", i, err)
			}
			if !ok {
				if c.AllowMissing {
					break forloop
				}
				return false, nil, nil
			}

			// Add each of the new values to each of the old values
			builder = append(builder, vals)

		default:
			return false, nil, fmt.Errorf("sub-index %d does not satisfy either SingleIndexer or MultiIndexer", i)
		}
	}

	// Start with something higher to avoid resizing if possible.
	// NOTE: this previously used len(c.Indexes)^3, but ^ is bitwise
	// XOR in Go (there is no exponentiation operator), so the hint was
	// effectively tiny; compute the intended cube instead. This only
	// affects the capacity hint, not the result.
	n := len(c.Indexes)
	out := make([][]byte, 0, n*n*n)

	// We are walking through the builder slice essentially in a depth-first fashion,
	// building the prefix and leaves as we go. If AllowMissing is false, we only insert
	// these full paths to leaves. Otherwise, we also insert each prefix along the way.
	// This allows for lookup in FromArgs when AllowMissing is true that does not contain
	// the full set of arguments. e.g. for {Foo, Bar} where an object has only the Foo
	// field specified as "abc", it is valid to call FromArgs with just "abc".
	var walkVals func([]byte, int)
	walkVals = func(currPrefix []byte, depth int) {
		if depth >= len(builder) {
			return
		}

		if depth == len(builder)-1 {
			// These are the "leaves", so append directly
			for _, v := range builder[depth] {
				outcome := make([]byte, len(currPrefix))
				copy(outcome, currPrefix)
				out = append(out, append(outcome, v...))
			}
			return
		}
		for _, v := range builder[depth] {
			// Build the extended prefix in a fresh slice rather than
			// append(currPrefix, v...): appending into a shared backing
			// array can let a sibling iteration overwrite the bytes of a
			// prefix that was already stored in out (AllowMissing path).
			nextPrefix := make([]byte, 0, len(currPrefix)+len(v))
			nextPrefix = append(nextPrefix, currPrefix...)
			nextPrefix = append(nextPrefix, v...)
			if c.AllowMissing {
				out = append(out, nextPrefix)
			}
			walkVals(nextPrefix, depth+1)
		}
	}

	walkVals(nil, 0)

	return true, out, nil
}

// FromArgs builds a lookup key from the given arguments. Each
// *StringMapFieldIndex sub-indexer consumes a pair of arguments (see
// the type comment); every other sub-indexer consumes exactly one.
// When AllowMissing is set, fewer arguments than sub-indexers may be
// supplied and the key is truncated accordingly.
func (c *CompoundMultiIndex) FromArgs(args ...interface{}) ([]byte, error) {
	// First pass: count how many arguments the supplied indexers will
	// consume so we can validate len(args) up front.
	var stringMapCount int
	var argCount int
	for _, index := range c.Indexes {
		if argCount >= len(args) {
			break
		}
		if _, ok := index.(*StringMapFieldIndex); ok {
			// We require pairs for StringMapFieldIndex, but only got one
			if argCount+1 >= len(args) {
				return nil, errors.New("invalid number of arguments")
			}
			stringMapCount++
			argCount += 2
		} else {
			argCount++
		}
	}
	argCount = 0

	switch c.AllowMissing {
	case true:
		if len(args) > len(c.Indexes)+stringMapCount {
			return nil, errors.New("too many arguments")
		}

	default:
		if len(args) != len(c.Indexes)+stringMapCount {
			return nil, errors.New("number of arguments does not equal number of indexers")
		}
	}

	var out []byte
	var val []byte
	var err error
	for i, idx := range c.Indexes {
		if argCount >= len(args) {
			// We're done; should only hit this if AllowMissing
			break
		}
		if _, ok := idx.(*StringMapFieldIndex); ok {
			if args[argCount+1] == nil {
				val, err = idx.FromArgs(args[argCount])
			} else {
				val,
err = idx.FromArgs(args[argCount : argCount+2]...) 922 | } 923 | argCount += 2 924 | } else { 925 | val, err = idx.FromArgs(args[argCount]) 926 | argCount++ 927 | } 928 | if err != nil { 929 | return nil, fmt.Errorf("sub-index %d error: %v", i, err) 930 | } 931 | out = append(out, val...) 932 | } 933 | return out, nil 934 | } 935 | -------------------------------------------------------------------------------- /index_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import ( 7 | "bytes" 8 | crand "crypto/rand" 9 | "encoding/binary" 10 | "fmt" 11 | "math/rand" 12 | "reflect" 13 | "sort" 14 | "strings" 15 | "testing" 16 | "testing/quick" 17 | "time" 18 | ) 19 | 20 | type TestObject struct { 21 | ID string 22 | Foo string 23 | Fu *string 24 | Boo *string 25 | Bar int 26 | Baz string 27 | Bam *bool 28 | Empty string 29 | Qux []string 30 | QuxEmpty []string 31 | Zod map[string]string 32 | ZodEmpty map[string]string 33 | Int int 34 | Int8 int8 35 | Int16 int16 36 | Int32 int32 37 | Int64 int64 38 | Uint uint 39 | Uint8 uint8 40 | Uint16 uint16 41 | Uint32 uint32 42 | Uint64 uint64 43 | Bool bool 44 | } 45 | 46 | func String(s string) *string { 47 | return &s 48 | } 49 | 50 | func testObj() *TestObject { 51 | b := true 52 | obj := &TestObject{ 53 | ID: "my-cool-obj", 54 | Foo: "Testing", 55 | Fu: String("Fu"), 56 | Boo: nil, 57 | Bar: 42, 58 | Baz: "yep", 59 | Bam: &b, 60 | Qux: []string{"Test", "Test2"}, 61 | Zod: map[string]string{ 62 | "Role": "Server", 63 | "instance_type": "m3.medium", 64 | "": "asdf", 65 | }, 66 | Int: int(1), 67 | Int8: int8(-1 << 7), 68 | Int16: int16(-1 << 15), 69 | Int32: int32(-1 << 31), 70 | Int64: int64(-1 << 63), 71 | Uint: uint(1), 72 | Uint8: uint8(1<<8 - 1), 73 | Uint16: uint16(1<<16 - 1), 74 | Uint32: uint32(1<<32 - 1), 75 | Uint64: uint64(1<<64 - 1), 76 | Bool: false, 77 | } 78 | return obj 
79 | } 80 | 81 | func TestStringFieldIndex_FromObject(t *testing.T) { 82 | obj := testObj() 83 | indexer := StringFieldIndex{"Foo", false} 84 | 85 | ok, val, err := indexer.FromObject(obj) 86 | if err != nil { 87 | t.Fatalf("err: %v", err) 88 | } 89 | if string(val) != "Testing\x00" { 90 | t.Fatalf("bad: %s", val) 91 | } 92 | if !ok { 93 | t.Fatalf("should be ok") 94 | } 95 | 96 | lower := StringFieldIndex{"Foo", true} 97 | ok, val, err = lower.FromObject(obj) 98 | if err != nil { 99 | t.Fatalf("err: %v", err) 100 | } 101 | if string(val) != "testing\x00" { 102 | t.Fatalf("bad: %s", val) 103 | } 104 | if !ok { 105 | t.Fatalf("should be ok") 106 | } 107 | 108 | badField := StringFieldIndex{"NA", true} 109 | ok, val, err = badField.FromObject(obj) 110 | if err == nil { 111 | t.Fatalf("should get error") 112 | } 113 | 114 | emptyField := StringFieldIndex{"Empty", true} 115 | ok, val, err = emptyField.FromObject(obj) 116 | if err != nil { 117 | t.Fatalf("err: %v", err) 118 | } 119 | if ok { 120 | t.Fatalf("should not ok") 121 | } 122 | 123 | pointerField := StringFieldIndex{"Fu", false} 124 | ok, val, err = pointerField.FromObject(obj) 125 | if err != nil { 126 | t.Fatalf("err: %v", err) 127 | } 128 | if string(val) != "Fu\x00" { 129 | t.Fatalf("bad: %s", val) 130 | } 131 | if !ok { 132 | t.Fatalf("should be ok") 133 | } 134 | 135 | pointerField = StringFieldIndex{"Boo", false} 136 | ok, val, err = pointerField.FromObject(obj) 137 | if err != nil { 138 | t.Fatalf("err: %v", err) 139 | } 140 | if string(val) != "" { 141 | t.Fatalf("bad: %s", val) 142 | } 143 | if ok { 144 | t.Fatalf("should be not ok") 145 | } 146 | } 147 | 148 | func TestStringFieldIndex_FromArgs(t *testing.T) { 149 | indexer := StringFieldIndex{"Foo", false} 150 | _, err := indexer.FromArgs() 151 | if err == nil { 152 | t.Fatalf("should get err") 153 | } 154 | 155 | _, err = indexer.FromArgs(42) 156 | if err == nil { 157 | t.Fatalf("should get err") 158 | } 159 | 160 | val, err := 
indexer.FromArgs("foo") 161 | if err != nil { 162 | t.Fatalf("err: %v", err) 163 | } 164 | if string(val) != "foo\x00" { 165 | t.Fatalf("foo") 166 | } 167 | 168 | lower := StringFieldIndex{"Foo", true} 169 | val, err = lower.FromArgs("Foo") 170 | if err != nil { 171 | t.Fatalf("err: %v", err) 172 | } 173 | if string(val) != "foo\x00" { 174 | t.Fatalf("foo") 175 | } 176 | } 177 | 178 | func TestStringFieldIndex_PrefixFromArgs(t *testing.T) { 179 | indexer := StringFieldIndex{"Foo", false} 180 | _, err := indexer.FromArgs() 181 | if err == nil { 182 | t.Fatalf("should get err") 183 | } 184 | 185 | _, err = indexer.PrefixFromArgs(42) 186 | if err == nil { 187 | t.Fatalf("should get err") 188 | } 189 | 190 | val, err := indexer.PrefixFromArgs("foo") 191 | if err != nil { 192 | t.Fatalf("err: %v", err) 193 | } 194 | if string(val) != "foo" { 195 | t.Fatalf("foo") 196 | } 197 | 198 | lower := StringFieldIndex{"Foo", true} 199 | val, err = lower.PrefixFromArgs("Foo") 200 | if err != nil { 201 | t.Fatalf("err: %v", err) 202 | } 203 | if string(val) != "foo" { 204 | t.Fatalf("foo") 205 | } 206 | } 207 | 208 | func TestStringSliceFieldIndex_FromObject(t *testing.T) { 209 | obj := testObj() 210 | 211 | indexer := StringSliceFieldIndex{"Qux", false} 212 | ok, vals, err := indexer.FromObject(obj) 213 | if err != nil { 214 | t.Fatalf("err: %v", err) 215 | } 216 | if len(vals) != 2 { 217 | t.Fatal("bad result length") 218 | } 219 | if string(vals[0]) != "Test\x00" { 220 | t.Fatalf("bad: %s", vals[0]) 221 | } 222 | if string(vals[1]) != "Test2\x00" { 223 | t.Fatalf("bad: %s", vals[1]) 224 | } 225 | if !ok { 226 | t.Fatalf("should be ok") 227 | } 228 | 229 | lower := StringSliceFieldIndex{"Qux", true} 230 | ok, vals, err = lower.FromObject(obj) 231 | if err != nil { 232 | t.Fatalf("err: %v", err) 233 | } 234 | if len(vals) != 2 { 235 | t.Fatal("bad result length") 236 | } 237 | if string(vals[0]) != "test\x00" { 238 | t.Fatalf("bad: %s", vals[0]) 239 | } 240 | if string(vals[1]) != 
"test2\x00" { 241 | t.Fatalf("bad: %s", vals[1]) 242 | } 243 | if !ok { 244 | t.Fatalf("should be ok") 245 | } 246 | 247 | badField := StringSliceFieldIndex{"NA", true} 248 | ok, vals, err = badField.FromObject(obj) 249 | if err == nil { 250 | t.Fatalf("should get error") 251 | } 252 | 253 | emptyField := StringSliceFieldIndex{"QuxEmpty", true} 254 | ok, vals, err = emptyField.FromObject(obj) 255 | if err != nil { 256 | t.Fatalf("err: %v", err) 257 | } 258 | if ok { 259 | t.Fatalf("should not ok") 260 | } 261 | } 262 | 263 | func TestStringSliceFieldIndex_FromArgs(t *testing.T) { 264 | indexer := StringSliceFieldIndex{"Qux", false} 265 | _, err := indexer.FromArgs() 266 | if err == nil { 267 | t.Fatalf("should get err") 268 | } 269 | 270 | _, err = indexer.FromArgs(42) 271 | if err == nil { 272 | t.Fatalf("should get err") 273 | } 274 | 275 | val, err := indexer.FromArgs("foo") 276 | if err != nil { 277 | t.Fatalf("err: %v", err) 278 | } 279 | if string(val) != "foo\x00" { 280 | t.Fatalf("foo") 281 | } 282 | 283 | lower := StringSliceFieldIndex{"Qux", true} 284 | val, err = lower.FromArgs("Foo") 285 | if err != nil { 286 | t.Fatalf("err: %v", err) 287 | } 288 | if string(val) != "foo\x00" { 289 | t.Fatalf("foo") 290 | } 291 | } 292 | 293 | func TestStringSliceFieldIndex_PrefixFromArgs(t *testing.T) { 294 | indexer := StringSliceFieldIndex{"Qux", false} 295 | _, err := indexer.FromArgs() 296 | if err == nil { 297 | t.Fatalf("should get err") 298 | } 299 | 300 | _, err = indexer.PrefixFromArgs(42) 301 | if err == nil { 302 | t.Fatalf("should get err") 303 | } 304 | 305 | val, err := indexer.PrefixFromArgs("foo") 306 | if err != nil { 307 | t.Fatalf("err: %v", err) 308 | } 309 | if string(val) != "foo" { 310 | t.Fatalf("foo") 311 | } 312 | 313 | lower := StringSliceFieldIndex{"Qux", true} 314 | val, err = lower.PrefixFromArgs("Foo") 315 | if err != nil { 316 | t.Fatalf("err: %v", err) 317 | } 318 | if string(val) != "foo" { 319 | t.Fatalf("foo") 320 | } 321 | } 322 | 
323 | func TestStringMapFieldIndex_FromObject(t *testing.T) { 324 | // Helper function to put the result in a deterministic order 325 | fromObjectSorted := func(index MultiIndexer, obj *TestObject) (bool, []string, error) { 326 | ok, v, err := index.FromObject(obj) 327 | var vals []string 328 | for _, s := range v { 329 | vals = append(vals, string(s)) 330 | } 331 | sort.Strings(vals) 332 | return ok, vals, err 333 | } 334 | 335 | obj := testObj() 336 | 337 | indexer := StringMapFieldIndex{"Zod", false} 338 | ok, vals, err := fromObjectSorted(&indexer, obj) 339 | if err != nil { 340 | t.Fatalf("err: %v", err) 341 | } 342 | if len(vals) != 2 { 343 | t.Fatalf("bad result length of %d", len(vals)) 344 | } 345 | if string(vals[0]) != "Role\x00Server\x00" { 346 | t.Fatalf("bad: %s", vals[0]) 347 | } 348 | if string(vals[1]) != "instance_type\x00m3.medium\x00" { 349 | t.Fatalf("bad: %s", vals[1]) 350 | } 351 | if !ok { 352 | t.Fatalf("should be ok") 353 | } 354 | 355 | lower := StringMapFieldIndex{"Zod", true} 356 | ok, vals, err = fromObjectSorted(&lower, obj) 357 | if err != nil { 358 | t.Fatalf("err: %v", err) 359 | } 360 | if len(vals) != 2 { 361 | t.Fatal("bad result length") 362 | } 363 | if string(vals[0]) != "instance_type\x00m3.medium\x00" { 364 | t.Fatalf("bad: %s", vals[0]) 365 | } 366 | if string(vals[1]) != "role\x00server\x00" { 367 | t.Fatalf("bad: %s", vals[1]) 368 | } 369 | if !ok { 370 | t.Fatalf("should be ok") 371 | } 372 | 373 | badField := StringMapFieldIndex{"NA", true} 374 | ok, _, err = badField.FromObject(obj) 375 | if err == nil { 376 | t.Fatalf("should get error") 377 | } 378 | 379 | emptyField := StringMapFieldIndex{"ZodEmpty", true} 380 | ok, _, err = emptyField.FromObject(obj) 381 | if err != nil { 382 | t.Fatalf("err: %v", err) 383 | } 384 | if ok { 385 | t.Fatalf("should not ok") 386 | } 387 | } 388 | 389 | func TestStringMapFieldIndex_FromArgs(t *testing.T) { 390 | indexer := StringMapFieldIndex{"Zod", false} 391 | _, err := 
indexer.FromArgs() 392 | if err == nil { 393 | t.Fatalf("should get err") 394 | } 395 | 396 | _, err = indexer.FromArgs(42) 397 | if err == nil { 398 | t.Fatalf("should get err") 399 | } 400 | 401 | val, err := indexer.FromArgs("Role", "Server") 402 | if err != nil { 403 | t.Fatalf("err: %v", err) 404 | } 405 | if string(val) != "Role\x00Server\x00" { 406 | t.Fatalf("bad: %v", string(val)) 407 | } 408 | 409 | lower := StringMapFieldIndex{"Zod", true} 410 | val, err = lower.FromArgs("Role", "Server") 411 | if err != nil { 412 | t.Fatalf("err: %v", err) 413 | } 414 | if string(val) != "role\x00server\x00" { 415 | t.Fatalf("bad: %v", string(val)) 416 | } 417 | } 418 | 419 | func TestUUIDFeldIndex_parseString(t *testing.T) { 420 | u := &UUIDFieldIndex{} 421 | _, err := u.parseString("invalid", true) 422 | if err == nil { 423 | t.Fatalf("should error") 424 | } 425 | 426 | buf, uuid := generateUUID() 427 | 428 | out, err := u.parseString(uuid, true) 429 | if err != nil { 430 | t.Fatalf("err: %v", err) 431 | } 432 | 433 | if !bytes.Equal(out, buf) { 434 | t.Fatalf("bad: %#v %#v", out, buf) 435 | } 436 | 437 | _, err = u.parseString("1-2-3-4-5-6", false) 438 | if err == nil { 439 | t.Fatalf("should error") 440 | } 441 | 442 | // Parse an empty string. 443 | out, err = u.parseString("", false) 444 | if err != nil { 445 | t.Fatalf("err: %v", err) 446 | } 447 | 448 | expected := []byte{} 449 | if !bytes.Equal(out, expected) { 450 | t.Fatalf("bad: %#v %#v", out, expected) 451 | } 452 | 453 | // Parse an odd length UUID. 454 | input := "f23" 455 | out, err = u.parseString(input, false) 456 | if err == nil { 457 | t.Fatalf("expect error") 458 | } 459 | 460 | // Parse an even length UUID with hyphen. 
461 | input = "20d8c509-3940-" 462 | out, err = u.parseString(input, false) 463 | if err != nil { 464 | t.Fatalf("err: %v", err) 465 | } 466 | 467 | expected = []byte{0x20, 0xd8, 0xc5, 0x09, 0x39, 0x40} 468 | if !bytes.Equal(out, expected) { 469 | t.Fatalf("bad: %#v %#v", out, expected) 470 | } 471 | } 472 | 473 | func TestUUIDFieldIndex_FromObject(t *testing.T) { 474 | obj := testObj() 475 | uuidBuf, uuid := generateUUID() 476 | obj.Foo = uuid 477 | indexer := &UUIDFieldIndex{"Foo"} 478 | 479 | ok, val, err := indexer.FromObject(obj) 480 | if err != nil { 481 | t.Fatalf("err: %v", err) 482 | } 483 | if !bytes.Equal(uuidBuf, val) { 484 | t.Fatalf("bad: %s", val) 485 | } 486 | if !ok { 487 | t.Fatalf("should be ok") 488 | } 489 | 490 | badField := &UUIDFieldIndex{"NA"} 491 | ok, val, err = badField.FromObject(obj) 492 | if err == nil { 493 | t.Fatalf("should get error") 494 | } 495 | 496 | emptyField := &UUIDFieldIndex{"Empty"} 497 | ok, val, err = emptyField.FromObject(obj) 498 | if err != nil { 499 | t.Fatalf("err: %v", err) 500 | } 501 | if ok { 502 | t.Fatalf("should not ok") 503 | } 504 | } 505 | 506 | func TestUUIDFieldIndex_FromArgs(t *testing.T) { 507 | indexer := &UUIDFieldIndex{"Foo"} 508 | _, err := indexer.FromArgs() 509 | if err == nil { 510 | t.Fatalf("should get err") 511 | } 512 | 513 | _, err = indexer.FromArgs(42) 514 | if err == nil { 515 | t.Fatalf("should get err") 516 | } 517 | 518 | uuidBuf, uuid := generateUUID() 519 | 520 | val, err := indexer.FromArgs(uuid) 521 | if err != nil { 522 | t.Fatalf("err: %v", err) 523 | } 524 | if !bytes.Equal(uuidBuf, val) { 525 | t.Fatalf("foo") 526 | } 527 | 528 | val, err = indexer.FromArgs(uuidBuf) 529 | if err != nil { 530 | t.Fatalf("err: %v", err) 531 | } 532 | if !bytes.Equal(uuidBuf, val) { 533 | t.Fatalf("foo") 534 | } 535 | } 536 | 537 | func TestUUIDFieldIndex_PrefixFromArgs(t *testing.T) { 538 | indexer := UUIDFieldIndex{"Foo"} 539 | _, err := indexer.FromArgs() 540 | if err == nil { 541 | 
t.Fatalf("should get err") 542 | } 543 | 544 | _, err = indexer.PrefixFromArgs(42) 545 | if err == nil { 546 | t.Fatalf("should get err") 547 | } 548 | 549 | uuidBuf, uuid := generateUUID() 550 | 551 | // Test full length. 552 | val, err := indexer.PrefixFromArgs(uuid) 553 | if err != nil { 554 | t.Fatalf("err: %v", err) 555 | } 556 | if !bytes.Equal(uuidBuf, val) { 557 | t.Fatalf("foo") 558 | } 559 | 560 | val, err = indexer.PrefixFromArgs(uuidBuf) 561 | if err != nil { 562 | t.Fatalf("err: %v", err) 563 | } 564 | if !bytes.Equal(uuidBuf, val) { 565 | t.Fatalf("foo") 566 | } 567 | 568 | // Test partial. 569 | val, err = indexer.PrefixFromArgs(uuid[:6]) 570 | if err != nil { 571 | t.Fatalf("err: %v", err) 572 | } 573 | if !bytes.Equal(uuidBuf[:3], val) { 574 | t.Fatalf("PrefixFromArgs returned %#v;\nwant %#v", val, uuidBuf[:3]) 575 | } 576 | 577 | val, err = indexer.PrefixFromArgs(uuidBuf[:9]) 578 | if err != nil { 579 | t.Fatalf("err: %v", err) 580 | } 581 | if !bytes.Equal(uuidBuf[:9], val) { 582 | t.Fatalf("foo") 583 | } 584 | } 585 | 586 | func BenchmarkUUIDFieldIndex_parseString(b *testing.B) { 587 | _, uuid := generateUUID() 588 | indexer := &UUIDFieldIndex{} 589 | for i := 0; i < b.N; i++ { 590 | _, err := indexer.parseString(uuid, true) 591 | if err != nil { 592 | b.FailNow() 593 | } 594 | } 595 | } 596 | 597 | func generateUUID() ([]byte, string) { 598 | buf := make([]byte, 16) 599 | if _, err := crand.Read(buf); err != nil { 600 | panic(fmt.Errorf("failed to read random bytes: %v", err)) 601 | } 602 | uuid := fmt.Sprintf("%08x-%04x-%04x-%04x-%12x", 603 | buf[0:4], 604 | buf[4:6], 605 | buf[6:8], 606 | buf[8:10], 607 | buf[10:16]) 608 | return buf, uuid 609 | } 610 | 611 | func TestIntFieldIndex_FromObject(t *testing.T) { 612 | obj := testObj() 613 | 614 | eint := make([]byte, 8) 615 | eint8 := make([]byte, 1) 616 | eint16 := make([]byte, 2) 617 | eint32 := make([]byte, 4) 618 | eint64 := make([]byte, 8) 619 | binary.BigEndian.PutUint64(eint, 1<<63+1) 620 
| eint8[0] = 0 621 | binary.BigEndian.PutUint16(eint16, 0) 622 | binary.BigEndian.PutUint32(eint32, 0) 623 | binary.BigEndian.PutUint64(eint64, 0) 624 | 625 | cases := []struct { 626 | Field string 627 | Expected []byte 628 | ErrorContains string 629 | }{ 630 | { 631 | Field: "Int", 632 | Expected: eint, 633 | }, 634 | { 635 | Field: "Int8", 636 | Expected: eint8, 637 | }, 638 | { 639 | Field: "Int16", 640 | Expected: eint16, 641 | }, 642 | { 643 | Field: "Int32", 644 | Expected: eint32, 645 | }, 646 | { 647 | Field: "Int64", 648 | Expected: eint64, 649 | }, 650 | { 651 | Field: "IntGarbage", 652 | ErrorContains: "is invalid", 653 | }, 654 | { 655 | Field: "ID", 656 | ErrorContains: "want an int", 657 | }, 658 | } 659 | 660 | for _, c := range cases { 661 | t.Run(c.Field, func(t *testing.T) { 662 | indexer := IntFieldIndex{c.Field} 663 | ok, val, err := indexer.FromObject(obj) 664 | if err != nil { 665 | if ok { 666 | t.Fatalf("okay and error") 667 | } 668 | 669 | if c.ErrorContains != "" && strings.Contains(err.Error(), c.ErrorContains) { 670 | return 671 | } else { 672 | t.Fatalf("Unexpected error %v", err) 673 | } 674 | } 675 | 676 | if !ok { 677 | t.Fatalf("not okay and no error") 678 | } 679 | 680 | if !bytes.Equal(val, c.Expected) { 681 | t.Fatalf("bad: %#v %#v", val, c.Expected) 682 | } 683 | }) 684 | } 685 | } 686 | 687 | func TestIntFieldIndex_FromArgs(t *testing.T) { 688 | indexer := IntFieldIndex{"Foo"} 689 | _, err := indexer.FromArgs() 690 | if err == nil { 691 | t.Fatalf("should get err") 692 | } 693 | 694 | _, err = indexer.FromArgs(int(1), int(2)) 695 | if err == nil { 696 | t.Fatalf("should get err") 697 | } 698 | 699 | _, err = indexer.FromArgs("foo") 700 | if err == nil { 701 | t.Fatalf("should get err") 702 | } 703 | 704 | obj := testObj() 705 | eint := make([]byte, 8) 706 | eint8 := make([]byte, 1) 707 | eint16 := make([]byte, 2) 708 | eint32 := make([]byte, 4) 709 | eint64 := make([]byte, 8) 710 | binary.BigEndian.PutUint64(eint, 1<<63+1) 711 
| eint8[0] = 0 712 | binary.BigEndian.PutUint16(eint16, 0) 713 | binary.BigEndian.PutUint32(eint32, 0) 714 | binary.BigEndian.PutUint64(eint64, 0) 715 | 716 | val, err := indexer.FromArgs(obj.Int) 717 | if err != nil { 718 | t.Fatalf("bad: %v", err) 719 | } 720 | if !bytes.Equal(val, eint) { 721 | t.Fatalf("bad: %#v %#v", val, eint) 722 | } 723 | 724 | val, err = indexer.FromArgs(obj.Int8) 725 | if err != nil { 726 | t.Fatalf("bad: %v", err) 727 | } 728 | if !bytes.Equal(val, eint8) { 729 | t.Fatalf("bad: %#v %#v", val, eint8) 730 | } 731 | 732 | val, err = indexer.FromArgs(obj.Int16) 733 | if err != nil { 734 | t.Fatalf("bad: %v", err) 735 | } 736 | if !bytes.Equal(val, eint16) { 737 | t.Fatalf("bad: %#v %#v", val, eint16) 738 | } 739 | 740 | val, err = indexer.FromArgs(obj.Int32) 741 | if err != nil { 742 | t.Fatalf("bad: %v", err) 743 | } 744 | if !bytes.Equal(val, eint32) { 745 | t.Fatalf("bad: %#v %#v", val, eint32) 746 | } 747 | 748 | val, err = indexer.FromArgs(obj.Int64) 749 | if err != nil { 750 | t.Fatalf("bad: %v", err) 751 | } 752 | if !bytes.Equal(val, eint64) { 753 | t.Fatalf("bad: %#v %#v", val, eint64) 754 | } 755 | } 756 | 757 | func TestIntFieldIndexSortability(t *testing.T) { 758 | testCases := []struct { 759 | i8l int8 760 | i8r int8 761 | i16l int16 762 | i16r int16 763 | i32l int32 764 | i32r int32 765 | i64l int64 766 | i64r int64 767 | il int 768 | ir int 769 | expected int 770 | name string 771 | }{ 772 | {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, "zero"}, 773 | {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, "small eq"}, 774 | {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, -1, "small lt"}, 775 | {2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, "small gt"}, 776 | {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, "small neg eq"}, 777 | {-2, -1, -2, -1, -2, -1, -2, -1, -2, -1, -1, "small neg lt"}, 778 | {-1, -2, -1, -2, -1, -2, -1, -2, -1, -2, 1, "small neg gt"}, 779 | {-1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, "neg vs pos"}, 780 | {-128, 127, -32768, 32767, -2147483648, 2147483647, -9223372036854775808, 
9223372036854775807, -9223372036854775808, 9223372036854775807, -1, "max conditions"}, 781 | {100, 127, 1000, 2000, 1000000000, 2000000000, 10000000000, 20000000000, 1000000000, 2000000000, -1, "large lt"}, 782 | {100, 99, 1000, 999, 1000000000, 999999999, 10000000000, 9999999999, 1000000000, 999999999, 1, "large gt"}, 783 | {126, 127, 255, 256, 65535, 65536, 4294967295, 4294967296, 65535, 65536, -1, "edge conditions"}, 784 | } 785 | 786 | for _, tc := range testCases { 787 | t.Run(tc.name, func(t *testing.T) { 788 | compareEncoded(t, &IntFieldIndex{"Foo"}, tc.i8l, tc.i8r, tc.expected) 789 | compareEncoded(t, &IntFieldIndex{"Foo"}, tc.i16l, tc.i16r, tc.expected) 790 | compareEncoded(t, &IntFieldIndex{"Foo"}, tc.i32l, tc.i32r, tc.expected) 791 | compareEncoded(t, &IntFieldIndex{"Foo"}, tc.i64l, tc.i64r, tc.expected) 792 | compareEncoded(t, &IntFieldIndex{"Foo"}, tc.il, tc.ir, tc.expected) 793 | }) 794 | } 795 | } 796 | 797 | func TestUintFieldIndex_FromObject(t *testing.T) { 798 | obj := testObj() 799 | 800 | euint := make([]byte, 8) 801 | euint8 := make([]byte, 1) 802 | euint16 := make([]byte, 2) 803 | euint32 := make([]byte, 4) 804 | euint64 := make([]byte, 8) 805 | binary.BigEndian.PutUint64(euint, uint64(obj.Uint)) 806 | euint8[0] = obj.Uint8 807 | binary.BigEndian.PutUint16(euint16, obj.Uint16) 808 | binary.BigEndian.PutUint32(euint32, obj.Uint32) 809 | binary.BigEndian.PutUint64(euint64, obj.Uint64) 810 | 811 | cases := []struct { 812 | Field string 813 | Expected []byte 814 | ErrorContains string 815 | }{ 816 | { 817 | Field: "Uint", 818 | Expected: euint, 819 | }, 820 | { 821 | Field: "Uint8", 822 | Expected: euint8, 823 | }, 824 | { 825 | Field: "Uint16", 826 | Expected: euint16, 827 | }, 828 | { 829 | Field: "Uint32", 830 | Expected: euint32, 831 | }, 832 | { 833 | Field: "Uint64", 834 | Expected: euint64, 835 | }, 836 | { 837 | Field: "UintGarbage", 838 | ErrorContains: "is invalid", 839 | }, 840 | { 841 | Field: "ID", 842 | ErrorContains: "want a uint", 
843 | }, 844 | } 845 | 846 | for _, c := range cases { 847 | t.Run(c.Field, func(t *testing.T) { 848 | indexer := UintFieldIndex{c.Field} 849 | ok, val, err := indexer.FromObject(obj) 850 | if err != nil { 851 | if ok { 852 | t.Fatalf("okay and error") 853 | } 854 | 855 | if c.ErrorContains != "" && strings.Contains(err.Error(), c.ErrorContains) { 856 | return 857 | } else { 858 | t.Fatalf("Unexpected error %v", err) 859 | } 860 | } 861 | 862 | if !ok { 863 | t.Fatalf("not okay and no error") 864 | } 865 | 866 | if !bytes.Equal(val, c.Expected) { 867 | t.Fatalf("bad: %#v %#v", val, c.Expected) 868 | } 869 | }) 870 | } 871 | } 872 | 873 | func TestUintFieldIndex_FromArgs(t *testing.T) { 874 | indexer := UintFieldIndex{"Foo"} 875 | _, err := indexer.FromArgs() 876 | if err == nil { 877 | t.Fatalf("should get err") 878 | } 879 | 880 | _, err = indexer.FromArgs(uint(1), uint(2)) 881 | if err == nil { 882 | t.Fatalf("should get err") 883 | } 884 | 885 | _, err = indexer.FromArgs("foo") 886 | if err == nil { 887 | t.Fatalf("should get err") 888 | } 889 | 890 | obj := testObj() 891 | euint := make([]byte, 8) 892 | euint8 := make([]byte, 1) 893 | euint16 := make([]byte, 2) 894 | euint32 := make([]byte, 4) 895 | euint64 := make([]byte, 8) 896 | binary.BigEndian.PutUint64(euint, uint64(obj.Uint)) 897 | euint8[0] = obj.Uint8 898 | binary.BigEndian.PutUint16(euint16, obj.Uint16) 899 | binary.BigEndian.PutUint32(euint32, obj.Uint32) 900 | binary.BigEndian.PutUint64(euint64, obj.Uint64) 901 | 902 | val, err := indexer.FromArgs(obj.Uint) 903 | if err != nil { 904 | t.Fatalf("bad: %v", err) 905 | } 906 | if !bytes.Equal(val, euint) { 907 | t.Fatalf("bad: %#v %#v", val, euint) 908 | } 909 | 910 | val, err = indexer.FromArgs(obj.Uint8) 911 | if err != nil { 912 | t.Fatalf("bad: %v", err) 913 | } 914 | if !bytes.Equal(val, euint8) { 915 | t.Fatalf("bad: %#v %#v", val, euint8) 916 | } 917 | 918 | val, err = indexer.FromArgs(obj.Uint16) 919 | if err != nil { 920 | t.Fatalf("bad: %v", 
err) 921 | } 922 | if !bytes.Equal(val, euint16) { 923 | t.Fatalf("bad: %#v %#v", val, euint16) 924 | } 925 | 926 | val, err = indexer.FromArgs(obj.Uint32) 927 | if err != nil { 928 | t.Fatalf("bad: %v", err) 929 | } 930 | if !bytes.Equal(val, euint32) { 931 | t.Fatalf("bad: %#v %#v", val, euint32) 932 | } 933 | 934 | val, err = indexer.FromArgs(obj.Uint64) 935 | if err != nil { 936 | t.Fatalf("bad: %v", err) 937 | } 938 | if !bytes.Equal(val, euint64) { 939 | t.Fatalf("bad: %#v %#v", val, euint64) 940 | } 941 | } 942 | 943 | func TestUIntFieldIndexSortability(t *testing.T) { 944 | testCases := []struct { 945 | u8l uint8 946 | u8r uint8 947 | u16l uint16 948 | u16r uint16 949 | u32l uint32 950 | u32r uint32 951 | u64l uint64 952 | u64r uint64 953 | ul uint 954 | ur uint 955 | expected int 956 | name string 957 | }{ 958 | {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, "zero"}, 959 | {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, "small eq"}, 960 | {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, -1, "small lt"}, 961 | {2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, "small gt"}, 962 | {100, 200, 1000, 2000, 1000000000, 2000000000, 10000000000, 20000000000, 1000000000, 2000000000, -1, "large lt"}, 963 | {100, 99, 1000, 999, 1000000000, 999999999, 10000000000, 9999999999, 1000000000, 999999999, 1, "large gt"}, 964 | {127, 128, 255, 256, 65535, 65536, 4294967295, 4294967296, 65535, 65536, -1, "edge conditions"}, 965 | } 966 | 967 | for _, tc := range testCases { 968 | t.Run(tc.name, func(t *testing.T) { 969 | compareEncoded(t, &UintFieldIndex{"Foo"}, tc.u8l, tc.u8r, tc.expected) 970 | compareEncoded(t, &UintFieldIndex{"Foo"}, tc.u16l, tc.u16r, tc.expected) 971 | compareEncoded(t, &UintFieldIndex{"Foo"}, tc.u32l, tc.u32r, tc.expected) 972 | compareEncoded(t, &UintFieldIndex{"Foo"}, tc.u64l, tc.u64r, tc.expected) 973 | compareEncoded(t, &UintFieldIndex{"Foo"}, tc.ul, tc.ur, tc.expected) 974 | }) 975 | } 976 | } 977 | 978 | func compareEncoded(t *testing.T, indexer Indexer, l interface{}, r interface{}, expected int) { 979 | 
lBytes, err := indexer.FromArgs(l) 980 | if err != nil { 981 | t.Fatalf("unable to encode %d: %s", l, err) 982 | } 983 | rBytes, err := indexer.FromArgs(r) 984 | if err != nil { 985 | t.Fatalf("unable to encode %d: %s", r, err) 986 | } 987 | 988 | if bytes.Compare(lBytes, rBytes) != expected { 989 | t.Fatalf("Compare(%#v, %#v) != %d", lBytes, rBytes, expected) 990 | } 991 | } 992 | 993 | func TestBoolFieldIndex_FromObject(t *testing.T) { 994 | obj := testObj() 995 | indexer := BoolFieldIndex{Field: "Bool"} 996 | 997 | obj.Bool = false 998 | ok, val, err := indexer.FromObject(obj) 999 | if err != nil { 1000 | t.Fatalf("err: %v", err) 1001 | } 1002 | if !ok { 1003 | t.Fatalf("should be ok") 1004 | } 1005 | if len(val) != 1 || val[0] != 0 { 1006 | t.Fatalf("bad: %v", val) 1007 | } 1008 | 1009 | obj.Bool = true 1010 | ok, val, err = indexer.FromObject(obj) 1011 | if err != nil { 1012 | t.Fatalf("err: %v", err) 1013 | } 1014 | if !ok { 1015 | t.Fatalf("should be ok") 1016 | } 1017 | if len(val) != 1 || val[0] != 1 { 1018 | t.Fatalf("bad: %v", val) 1019 | } 1020 | 1021 | indexer = BoolFieldIndex{Field: "NA"} 1022 | ok, val, err = indexer.FromObject(obj) 1023 | if err == nil { 1024 | t.Fatalf("should get error") 1025 | } 1026 | 1027 | indexer = BoolFieldIndex{Field: "ID"} 1028 | ok, val, err = indexer.FromObject(obj) 1029 | if err == nil { 1030 | t.Fatalf("should get error") 1031 | } 1032 | } 1033 | 1034 | func TestBoolFieldIndex_FromArgs(t *testing.T) { 1035 | indexer := BoolFieldIndex{Field: "Bool"} 1036 | 1037 | val, err := indexer.FromArgs() 1038 | if err == nil { 1039 | t.Fatalf("should get err") 1040 | } 1041 | 1042 | val, err = indexer.FromArgs(42) 1043 | if err == nil { 1044 | t.Fatalf("should get err") 1045 | } 1046 | 1047 | val, err = indexer.FromArgs(true) 1048 | if err != nil { 1049 | t.Fatalf("err: %v", err) 1050 | } 1051 | if len(val) != 1 || val[0] != 1 { 1052 | t.Fatalf("bad: %v", val) 1053 | } 1054 | 1055 | val, err = indexer.FromArgs(false) 1056 | if err 
!= nil { 1057 | t.Fatalf("err: %v", err) 1058 | } 1059 | if len(val) != 1 || val[0] != 0 { 1060 | t.Fatalf("bad: %v", val) 1061 | } 1062 | } 1063 | 1064 | func TestFieldSetIndex_FromObject(t *testing.T) { 1065 | obj := testObj() 1066 | indexer := FieldSetIndex{"Bam"} 1067 | 1068 | ok, val, err := indexer.FromObject(obj) 1069 | if err != nil { 1070 | t.Fatalf("err: %v", err) 1071 | } 1072 | if !ok { 1073 | t.Fatalf("should be ok") 1074 | } 1075 | if len(val) != 1 || val[0] != 1 { 1076 | t.Fatalf("bad: %v", val) 1077 | } 1078 | 1079 | emptyIndexer := FieldSetIndex{"Empty"} 1080 | ok, val, err = emptyIndexer.FromObject(obj) 1081 | if err != nil { 1082 | t.Fatalf("err: %v", err) 1083 | } 1084 | if !ok { 1085 | t.Fatalf("should be ok") 1086 | } 1087 | if len(val) != 1 || val[0] != 0 { 1088 | t.Fatalf("bad: %v", val) 1089 | } 1090 | 1091 | setIndexer := FieldSetIndex{"Bar"} 1092 | ok, val, err = setIndexer.FromObject(obj) 1093 | if err != nil { 1094 | t.Fatalf("err: %v", err) 1095 | } 1096 | if len(val) != 1 || val[0] != 1 { 1097 | t.Fatalf("bad: %v", val) 1098 | } 1099 | if !ok { 1100 | t.Fatalf("should be ok") 1101 | } 1102 | 1103 | badField := FieldSetIndex{"NA"} 1104 | ok, val, err = badField.FromObject(obj) 1105 | if err == nil { 1106 | t.Fatalf("should get error") 1107 | } 1108 | 1109 | obj.Bam = nil 1110 | nilIndexer := FieldSetIndex{"Bam"} 1111 | ok, val, err = nilIndexer.FromObject(obj) 1112 | if err != nil { 1113 | t.Fatalf("err: %v", err) 1114 | } 1115 | if !ok { 1116 | t.Fatalf("should be ok") 1117 | } 1118 | if len(val) != 1 || val[0] != 0 { 1119 | t.Fatalf("bad: %v", val) 1120 | } 1121 | } 1122 | 1123 | func TestFieldSetIndex_FromArgs(t *testing.T) { 1124 | indexer := FieldSetIndex{"Bam"} 1125 | _, err := indexer.FromArgs() 1126 | if err == nil { 1127 | t.Fatalf("should get err") 1128 | } 1129 | 1130 | _, err = indexer.FromArgs(42) 1131 | if err == nil { 1132 | t.Fatalf("should get err") 1133 | } 1134 | 1135 | val, err := indexer.FromArgs(true) 1136 | if 
err != nil { 1137 | t.Fatalf("err: %v", err) 1138 | } 1139 | if len(val) != 1 || val[0] != 1 { 1140 | t.Fatalf("bad: %v", val) 1141 | } 1142 | 1143 | val, err = indexer.FromArgs(false) 1144 | if err != nil { 1145 | t.Fatalf("err: %v", err) 1146 | } 1147 | if len(val) != 1 || val[0] != 0 { 1148 | t.Fatalf("bad: %v", val) 1149 | } 1150 | } 1151 | 1152 | // A conditional that checks if TestObject.Bar == 42 1153 | var conditional = func(obj interface{}) (bool, error) { 1154 | test, ok := obj.(*TestObject) 1155 | if !ok { 1156 | return false, fmt.Errorf("Expect only TestObj types") 1157 | } 1158 | 1159 | if test.Bar != 42 { 1160 | return false, nil 1161 | } 1162 | 1163 | return true, nil 1164 | } 1165 | 1166 | func TestConditionalIndex_FromObject(t *testing.T) { 1167 | obj := testObj() 1168 | indexer := ConditionalIndex{conditional} 1169 | obj.Bar = 42 1170 | ok, val, err := indexer.FromObject(obj) 1171 | if err != nil { 1172 | t.Fatalf("err: %v", err) 1173 | } 1174 | if !ok { 1175 | t.Fatalf("should be ok") 1176 | } 1177 | if len(val) != 1 || val[0] != 1 { 1178 | t.Fatalf("bad: %v", val) 1179 | } 1180 | 1181 | // Change the object so it should return false. 1182 | obj.Bar = 2 1183 | ok, val, err = indexer.FromObject(obj) 1184 | if err != nil { 1185 | t.Fatalf("err: %v", err) 1186 | } 1187 | if !ok { 1188 | t.Fatalf("should be ok") 1189 | } 1190 | if len(val) != 1 || val[0] != 0 { 1191 | t.Fatalf("bad: %v", val) 1192 | } 1193 | 1194 | // Pass an invalid type. 
1195 | ok, val, err = indexer.FromObject(t) 1196 | if err == nil { 1197 | t.Fatalf("expected an error when passing invalid type") 1198 | } 1199 | } 1200 | 1201 | func TestConditionalIndex_FromArgs(t *testing.T) { 1202 | indexer := ConditionalIndex{conditional} 1203 | _, err := indexer.FromArgs() 1204 | if err == nil { 1205 | t.Fatalf("should get err") 1206 | } 1207 | 1208 | _, err = indexer.FromArgs(42) 1209 | if err == nil { 1210 | t.Fatalf("should get err") 1211 | } 1212 | 1213 | val, err := indexer.FromArgs(true) 1214 | if err != nil { 1215 | t.Fatalf("err: %v", err) 1216 | } 1217 | if len(val) != 1 || val[0] != 1 { 1218 | t.Fatalf("bad: %v", val) 1219 | } 1220 | 1221 | val, err = indexer.FromArgs(false) 1222 | if err != nil { 1223 | t.Fatalf("err: %v", err) 1224 | } 1225 | if len(val) != 1 || val[0] != 0 { 1226 | t.Fatalf("bad: %v", val) 1227 | } 1228 | } 1229 | 1230 | func TestCompoundIndex_FromObject(t *testing.T) { 1231 | obj := testObj() 1232 | indexer := &CompoundIndex{ 1233 | Indexes: []Indexer{ 1234 | &StringFieldIndex{"ID", false}, 1235 | &StringFieldIndex{"Foo", false}, 1236 | &StringFieldIndex{"Baz", false}, 1237 | }, 1238 | AllowMissing: false, 1239 | } 1240 | 1241 | ok, val, err := indexer.FromObject(obj) 1242 | if err != nil { 1243 | t.Fatalf("err: %v", err) 1244 | } 1245 | if string(val) != "my-cool-obj\x00Testing\x00yep\x00" { 1246 | t.Fatalf("bad: %s", val) 1247 | } 1248 | if !ok { 1249 | t.Fatalf("should be ok") 1250 | } 1251 | 1252 | missing := &CompoundIndex{ 1253 | Indexes: []Indexer{ 1254 | &StringFieldIndex{"ID", false}, 1255 | &StringFieldIndex{"Foo", true}, 1256 | &StringFieldIndex{"Empty", false}, 1257 | }, 1258 | AllowMissing: true, 1259 | } 1260 | ok, val, err = missing.FromObject(obj) 1261 | if err != nil { 1262 | t.Fatalf("err: %v", err) 1263 | } 1264 | if string(val) != "my-cool-obj\x00testing\x00" { 1265 | t.Fatalf("bad: %s", val) 1266 | } 1267 | if !ok { 1268 | t.Fatalf("should be ok") 1269 | } 1270 | 1271 | // Test when missing 
not allowed 1272 | missing.AllowMissing = false 1273 | ok, _, err = missing.FromObject(obj) 1274 | if err != nil { 1275 | t.Fatalf("err: %v", err) 1276 | } 1277 | if ok { 1278 | t.Fatalf("should not be okay") 1279 | } 1280 | } 1281 | 1282 | func TestCompoundIndex_FromArgs(t *testing.T) { 1283 | indexer := &CompoundIndex{ 1284 | Indexes: []Indexer{ 1285 | &StringFieldIndex{"ID", false}, 1286 | &StringFieldIndex{"Foo", false}, 1287 | &StringFieldIndex{"Baz", false}, 1288 | }, 1289 | AllowMissing: false, 1290 | } 1291 | _, err := indexer.FromArgs() 1292 | if err == nil { 1293 | t.Fatalf("should get err") 1294 | } 1295 | 1296 | _, err = indexer.FromArgs(42, 42, 42) 1297 | if err == nil { 1298 | t.Fatalf("should get err") 1299 | } 1300 | 1301 | val, err := indexer.FromArgs("foo", "bar", "baz") 1302 | if err != nil { 1303 | t.Fatalf("err: %v", err) 1304 | } 1305 | if string(val) != "foo\x00bar\x00baz\x00" { 1306 | t.Fatalf("bad: %s", val) 1307 | } 1308 | } 1309 | 1310 | func TestCompoundIndex_PrefixFromArgs(t *testing.T) { 1311 | indexer := &CompoundIndex{ 1312 | Indexes: []Indexer{ 1313 | &UUIDFieldIndex{"ID"}, 1314 | &StringFieldIndex{"Foo", false}, 1315 | &StringFieldIndex{"Baz", false}, 1316 | }, 1317 | AllowMissing: false, 1318 | } 1319 | val, err := indexer.PrefixFromArgs() 1320 | if err != nil { 1321 | t.Fatalf("err: %v", err) 1322 | } 1323 | if len(val) != 0 { 1324 | t.Fatalf("bad: %s", val) 1325 | } 1326 | 1327 | uuidBuf, uuid := generateUUID() 1328 | val, err = indexer.PrefixFromArgs(uuid, "foo") 1329 | if err != nil { 1330 | t.Fatalf("err: %v", err) 1331 | } 1332 | if !bytes.Equal(val[:16], uuidBuf) { 1333 | t.Fatalf("bad prefix") 1334 | } 1335 | if string(val[16:]) != "foo" { 1336 | t.Fatalf("bad: %s", val) 1337 | } 1338 | 1339 | val, err = indexer.PrefixFromArgs(uuid, "foo", "ba") 1340 | if err != nil { 1341 | t.Fatalf("err: %v", err) 1342 | } 1343 | if !bytes.Equal(val[:16], uuidBuf) { 1344 | t.Fatalf("bad prefix") 1345 | } 1346 | if string(val[16:]) != 
"foo\x00ba" { 1347 | t.Fatalf("bad: %s", val) 1348 | } 1349 | 1350 | _, err = indexer.PrefixFromArgs(uuid, "foo", "bar", "nope") 1351 | if err == nil { 1352 | t.Fatalf("expected an error when passing too many arguments") 1353 | } 1354 | } 1355 | 1356 | func TestCompoundMultiIndex_FromObject(t *testing.T) { 1357 | // handle sub-indexer case unique to MultiIndexer 1358 | obj := &TestObject{ 1359 | ID: "obj1-uuid", 1360 | Foo: "Foo1", 1361 | Baz: "yep", 1362 | Qux: []string{"Test", "Test2"}, 1363 | QuxEmpty: []string{"Qux", "Qux2"}, 1364 | } 1365 | indexer := &CompoundMultiIndex{ 1366 | Indexes: []Indexer{ 1367 | &StringFieldIndex{Field: "Foo"}, 1368 | &StringSliceFieldIndex{Field: "Qux"}, 1369 | &StringSliceFieldIndex{Field: "QuxEmpty"}, 1370 | }, 1371 | } 1372 | 1373 | ok, vals, err := indexer.FromObject(obj) 1374 | if err != nil { 1375 | t.Fatalf("err: %v", err) 1376 | } 1377 | if !ok { 1378 | t.Fatalf("should be ok") 1379 | } 1380 | want := []string{ 1381 | "Foo1\x00Test\x00Qux\x00", 1382 | "Foo1\x00Test\x00Qux2\x00", 1383 | "Foo1\x00Test2\x00Qux\x00", 1384 | "Foo1\x00Test2\x00Qux2\x00", 1385 | } 1386 | got := make([]string, len(vals)) 1387 | for i, v := range vals { 1388 | got[i] = string(v) 1389 | } 1390 | if !reflect.DeepEqual(got, want) { 1391 | t.Fatalf("\ngot: %+v\nwant: %+v\n", got, want) 1392 | } 1393 | } 1394 | 1395 | func TestCompoundMultiIndex_FromObject_IndexUniquenessProperty(t *testing.T) { 1396 | indexPermutations := [][]string{ 1397 | {"Foo", "Qux", "QuxEmpty"}, 1398 | {"Foo", "QuxEmpty", "Qux"}, 1399 | {"QuxEmpty", "Qux", "Foo"}, 1400 | {"QuxEmpty", "Foo", "Qux"}, 1401 | {"Qux", "QuxEmpty", "Foo"}, 1402 | {"Qux", "Foo", "QuxEmpty"}, 1403 | } 1404 | 1405 | fn := func(o TestObject) bool { 1406 | for _, perm := range indexPermutations { 1407 | indexer := indexerFromFieldNameList(perm) 1408 | ok, vals, err := indexer.FromObject(o) 1409 | if err != nil { 1410 | t.Logf("err: %v", err) 1411 | return false 1412 | } 1413 | if !ok { 1414 | t.Logf("should be 
ok") 1415 | return false 1416 | } 1417 | if !assertAllUnique(t, vals) { 1418 | return false 1419 | } 1420 | } 1421 | return true 1422 | } 1423 | seed := time.Now().UnixNano() 1424 | t.Logf("Using seed %v", seed) 1425 | cfg := quick.Config{Rand: rand.New(rand.NewSource(seed))} 1426 | if err := quick.Check(fn, &cfg); err != nil { 1427 | t.Fatalf("property not held: %v", err) 1428 | } 1429 | } 1430 | 1431 | func assertAllUnique(t *testing.T, vals [][]byte) bool { 1432 | t.Helper() 1433 | s := make(map[string]struct{}, len(vals)) 1434 | for _, index := range vals { 1435 | s[string(index)] = struct{}{} 1436 | } 1437 | 1438 | if l := len(s); l != len(vals) { 1439 | t.Logf("expected %d unique indexes, got %v", len(vals), l) 1440 | return false 1441 | } 1442 | return true 1443 | } 1444 | 1445 | func indexerFromFieldNameList(keys []string) *CompoundMultiIndex { 1446 | indexer := &CompoundMultiIndex{AllowMissing: true} 1447 | for _, key := range keys { 1448 | if key == "Foo" || key == "Baz" { 1449 | indexer.Indexes = append(indexer.Indexes, &StringFieldIndex{Field: key}) 1450 | continue 1451 | } 1452 | indexer.Indexes = append(indexer.Indexes, &StringSliceFieldIndex{Field: key}) 1453 | } 1454 | return indexer 1455 | } 1456 | 1457 | func BenchmarkCompoundMultiIndex_FromObject(b *testing.B) { 1458 | obj := &TestObject{ 1459 | ID: "obj1-uuid", 1460 | Foo: "Foo1", 1461 | Baz: "yep", 1462 | Qux: []string{"Test", "Test2"}, 1463 | QuxEmpty: []string{"Qux", "Qux2"}, 1464 | } 1465 | indexer := &CompoundMultiIndex{ 1466 | Indexes: []Indexer{ 1467 | &StringFieldIndex{Field: "Foo"}, 1468 | &StringSliceFieldIndex{Field: "Qux"}, 1469 | &StringSliceFieldIndex{Field: "QuxEmpty"}, 1470 | }, 1471 | } 1472 | 1473 | b.ResetTimer() 1474 | for i := 0; i < b.N; i++ { 1475 | ok, vals, err := indexer.FromObject(obj) 1476 | if err != nil { 1477 | b.Fatalf("expected no error, got: %v", err) 1478 | } 1479 | if !ok { 1480 | b.Fatalf("should be ok") 1481 | } 1482 | if l := len(vals); l != 4 { 1483 | 
b.Fatalf("expected 4 indexes, got %v", l) 1484 | } 1485 | } 1486 | } 1487 | -------------------------------------------------------------------------------- /integ_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import ( 7 | "reflect" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | // Test that multiple concurrent transactions are isolated from each other 13 | func TestTxn_Isolation(t *testing.T) { 14 | db := testDB(t) 15 | txn1 := db.Txn(true) 16 | 17 | obj := &TestObject{ 18 | ID: "my-object", 19 | Foo: "abc", 20 | Qux: []string{"abc1", "abc2"}, 21 | } 22 | obj2 := &TestObject{ 23 | ID: "my-cool-thing", 24 | Foo: "xyz", 25 | Qux: []string{"xyz1", "xyz2"}, 26 | } 27 | obj3 := &TestObject{ 28 | ID: "my-other-cool-thing", 29 | Foo: "xyz", 30 | Qux: []string{"xyz1", "xyz2"}, 31 | } 32 | 33 | err := txn1.Insert("main", obj) 34 | if err != nil { 35 | t.Fatalf("err: %v", err) 36 | } 37 | err = txn1.Insert("main", obj2) 38 | if err != nil { 39 | t.Fatalf("err: %v", err) 40 | } 41 | err = txn1.Insert("main", obj3) 42 | if err != nil { 43 | t.Fatalf("err: %v", err) 44 | } 45 | 46 | // Results should show up in this transaction 47 | raw, err := txn1.First("main", "id") 48 | if err != nil { 49 | t.Fatalf("err: %v", err) 50 | } 51 | if raw == nil { 52 | t.Fatalf("bad: %#v", raw) 53 | } 54 | 55 | // Create a new transaction, current one is NOT committed 56 | txn2 := db.Txn(false) 57 | 58 | // Nothing should show up in this transaction 59 | raw, err = txn2.First("main", "id") 60 | if err != nil { 61 | t.Fatalf("err: %v", err) 62 | } 63 | if raw != nil { 64 | t.Fatalf("bad: %#v", raw) 65 | } 66 | 67 | // Commit txn1, txn2 should still be isolated 68 | txn1.Commit() 69 | 70 | // Nothing should show up in this transaction 71 | raw, err = txn2.First("main", "id") 72 | if err != nil { 73 | t.Fatalf("err: %v", err) 74 | } 75 | if raw != nil { 
76 | t.Fatalf("bad: %#v", raw) 77 | } 78 | 79 | // Create a new txn 80 | txn3 := db.Txn(false) 81 | 82 | // Results should show up in this transaction 83 | raw, err = txn3.First("main", "id") 84 | if err != nil { 85 | t.Fatalf("err: %v", err) 86 | } 87 | if raw == nil { 88 | t.Fatalf("bad: %#v", raw) 89 | } 90 | } 91 | 92 | // Test that an abort clears progress 93 | func TestTxn_Abort(t *testing.T) { 94 | db := testDB(t) 95 | txn1 := db.Txn(true) 96 | 97 | obj := &TestObject{ 98 | ID: "my-object", 99 | Foo: "abc", 100 | Qux: []string{"abc1", "abc2"}, 101 | } 102 | obj2 := &TestObject{ 103 | ID: "my-cool-thing", 104 | Foo: "xyz", 105 | Qux: []string{"xyz1", "xyz2"}, 106 | } 107 | obj3 := &TestObject{ 108 | ID: "my-other-cool-thing", 109 | Foo: "xyz", 110 | Qux: []string{"xyz1", "xyz2"}, 111 | } 112 | 113 | err := txn1.Insert("main", obj) 114 | if err != nil { 115 | t.Fatalf("err: %v", err) 116 | } 117 | err = txn1.Insert("main", obj2) 118 | if err != nil { 119 | t.Fatalf("err: %v", err) 120 | } 121 | err = txn1.Insert("main", obj3) 122 | if err != nil { 123 | t.Fatalf("err: %v", err) 124 | } 125 | 126 | // Abort the txn 127 | txn1.Abort() 128 | txn1.Commit() 129 | 130 | // Create a new transaction 131 | txn2 := db.Txn(false) 132 | 133 | // Nothing should show up in this transaction 134 | raw, err := txn2.First("main", "id") 135 | if err != nil { 136 | t.Fatalf("err: %v", err) 137 | } 138 | if raw != nil { 139 | t.Fatalf("bad: %#v", raw) 140 | } 141 | } 142 | 143 | func TestComplexDB(t *testing.T) { 144 | db := testComplexDB(t) 145 | testPopulateData(t, db) 146 | txn := db.Txn(false) // read only 147 | 148 | // Get using a full name 149 | raw, err := txn.First("people", "name", "Armon", "Dadgar") 150 | noErr(t, err) 151 | if raw == nil { 152 | t.Fatalf("should get person") 153 | } 154 | 155 | // Get using a prefix 156 | raw, err = txn.First("people", "name_prefix", "Armon") 157 | noErr(t, err) 158 | if raw == nil { 159 | t.Fatalf("should get person") 160 | } 161 | 
162 | raw, err = txn.First("people", "id_prefix", raw.(*TestPerson).ID[:4]) 163 | noErr(t, err) 164 | if raw == nil { 165 | t.Fatalf("should get person") 166 | } 167 | 168 | // Get based on field set. 169 | result, err := txn.Get("people", "sibling", true) 170 | noErr(t, err) 171 | if raw == nil { 172 | t.Fatalf("should get person") 173 | } 174 | 175 | exp := map[string]bool{"Alex": true, "Armon": true} 176 | act := make(map[string]bool, 2) 177 | for i := result.Next(); i != nil; i = result.Next() { 178 | p, ok := i.(*TestPerson) 179 | if !ok { 180 | t.Fatalf("should get person") 181 | } 182 | act[p.First] = true 183 | } 184 | 185 | if !reflect.DeepEqual(act, exp) { 186 | t.Fatalf("Got %#v; want %#v", act, exp) 187 | } 188 | 189 | raw, err = txn.First("people", "sibling", false) 190 | noErr(t, err) 191 | if raw == nil { 192 | t.Fatalf("should get person") 193 | } 194 | if raw.(*TestPerson).First != "Mitchell" { 195 | t.Fatalf("wrong person!") 196 | } 197 | 198 | raw, err = txn.First("people", "age", uint8(23)) 199 | noErr(t, err) 200 | if raw == nil { 201 | t.Fatalf("should get person") 202 | } 203 | 204 | raw, err = txn.First("people", "negative_age", int8(-23)) 205 | noErr(t, err) 206 | if raw == nil { 207 | t.Fatalf("should get person") 208 | } 209 | 210 | person := raw.(*TestPerson) 211 | if person.First != "Alex" { 212 | t.Fatalf("wrong person!") 213 | } 214 | 215 | // Where in the world is mitchell hashimoto? 
216 | raw, err = txn.First("people", "name_prefix", "Mitchell") 217 | noErr(t, err) 218 | if raw == nil { 219 | t.Fatalf("should get person") 220 | } 221 | 222 | person = raw.(*TestPerson) 223 | if person.First != "Mitchell" { 224 | t.Fatalf("wrong person!") 225 | } 226 | 227 | raw, err = txn.First("visits", "id_prefix", person.ID) 228 | noErr(t, err) 229 | if raw == nil { 230 | t.Fatalf("should get visit") 231 | } 232 | 233 | visit := raw.(*TestVisit) 234 | 235 | raw, err = txn.First("places", "id", visit.Place) 236 | noErr(t, err) 237 | if raw == nil { 238 | t.Fatalf("should get place") 239 | } 240 | 241 | place := raw.(*TestPlace) 242 | if place.Name != "Maui" { 243 | t.Fatalf("bad place (but isn't anywhere else really?): %v", place) 244 | } 245 | 246 | raw, err = txn.First("places", "name_tags", "HashiCorp", "North America") 247 | noErr(t, err) 248 | if raw == nil { 249 | t.Fatalf("should get place") 250 | } 251 | place = raw.(*TestPlace) 252 | if place.Name != "HashiCorp" { 253 | t.Fatalf("bad place (but isn't anywhere else really?): %v", place) 254 | } 255 | 256 | raw, err = txn.First("places", "name_tags", "Maui") 257 | noErr(t, err) 258 | if raw == nil { 259 | t.Fatalf("should get place") 260 | } 261 | place = raw.(*TestPlace) 262 | if place.Name != "Maui" { 263 | t.Fatalf("bad place (but isn't anywhere else really?): %v", place) 264 | } 265 | 266 | raw, err = txn.First("places", "name_tags_name_meta", "HashiCorp", "North America", "HashiCorp", "Food", "Pretty Good") 267 | noErr(t, err) 268 | if raw == nil { 269 | t.Fatalf("should get place") 270 | } 271 | place = raw.(*TestPlace) 272 | if place.Tags[1] != "USA" { 273 | t.Fatalf("bad place: %v", place) 274 | } 275 | 276 | raw, err = txn.First("places", "name_tags_name_meta", "HashiCorp", "North America", "HashiCorp", "Piers", "Pretty Salty") 277 | noErr(t, err) 278 | if raw == nil { 279 | t.Fatalf("should get place") 280 | } 281 | place = raw.(*TestPlace) 282 | if place.Tags[1] != "Earth" { 283 | 
t.Fatalf("bad place: %v", place) 284 | } 285 | } 286 | 287 | func TestWatchUpdate(t *testing.T) { 288 | db := testComplexDB(t) 289 | testPopulateData(t, db) 290 | txn := db.Txn(false) // read only 291 | 292 | watchSetIter := NewWatchSet() 293 | watchSetSpecific := NewWatchSet() 294 | watchSetPrefix := NewWatchSet() 295 | 296 | // Get using an iterator. 297 | iter, err := txn.Get("people", "name", "Armon", "Dadgar") 298 | noErr(t, err) 299 | watchSetIter.Add(iter.WatchCh()) 300 | if raw := iter.Next(); raw == nil { 301 | t.Fatalf("should get person") 302 | } 303 | 304 | // Get using a full name. 305 | watch, raw, err := txn.FirstWatch("people", "name", "Armon", "Dadgar") 306 | noErr(t, err) 307 | if raw == nil { 308 | t.Fatalf("should get person") 309 | } 310 | watchSetSpecific.Add(watch) 311 | 312 | // Get using a prefix. 313 | watch, raw, err = txn.FirstWatch("people", "name_prefix", "Armon") 314 | noErr(t, err) 315 | if raw == nil { 316 | t.Fatalf("should get person") 317 | } 318 | watchSetPrefix.Add(watch) 319 | 320 | // Write to a snapshot. 321 | snap := db.Snapshot() 322 | txn2 := snap.Txn(true) // write 323 | noErr(t, txn2.Delete("people", raw)) 324 | txn2.Commit() 325 | 326 | // None of the watches should trigger since we didn't alter the 327 | // primary. 328 | wait := 100 * time.Millisecond 329 | if timeout := watchSetIter.Watch(time.After(wait)); !timeout { 330 | t.Fatalf("should timeout") 331 | } 332 | if timeout := watchSetSpecific.Watch(time.After(wait)); !timeout { 333 | t.Fatalf("should timeout") 334 | } 335 | if timeout := watchSetPrefix.Watch(time.After(wait)); !timeout { 336 | t.Fatalf("should timeout") 337 | } 338 | 339 | // Write to the primary. 340 | txn3 := db.Txn(true) // write 341 | noErr(t, txn3.Delete("people", raw)) 342 | txn3.Commit() 343 | 344 | // All three watches should trigger! 
345 | wait = time.Second 346 | if timeout := watchSetIter.Watch(time.After(wait)); timeout { 347 | t.Fatalf("should not timeout") 348 | } 349 | if timeout := watchSetSpecific.Watch(time.After(wait)); timeout { 350 | t.Fatalf("should not timeout") 351 | } 352 | if timeout := watchSetPrefix.Watch(time.After(wait)); timeout { 353 | t.Fatalf("should not timeout") 354 | } 355 | } 356 | 357 | func testPopulateData(t *testing.T, db *MemDB) { 358 | // Start write txn 359 | txn := db.Txn(true) 360 | 361 | // Create some data 362 | person1 := testPerson() 363 | 364 | person2 := testPerson() 365 | person2.First = "Mitchell" 366 | person2.Last = "Hashimoto" 367 | person2.Age = 27 368 | person2.NegativeAge = -27 369 | 370 | person3 := testPerson() 371 | person3.First = "Alex" 372 | person3.Last = "Dadgar" 373 | person3.Age = 23 374 | person3.NegativeAge = -23 375 | 376 | person1.Sibling = person3 377 | person3.Sibling = person1 378 | 379 | place1 := testPlace() 380 | place1.Tags = []string{"North America", "USA"} 381 | place1.Meta = map[string]string{"Food": "Pretty Good"} 382 | place2 := testPlace() 383 | place2.Name = "Maui" 384 | place3 := testPlace() 385 | place3.Tags = []string{"North America", "Earth"} 386 | place3.Meta = map[string]string{"Piers": "Pretty Salty"} 387 | 388 | visit1 := &TestVisit{person1.ID, place1.ID} 389 | visit2 := &TestVisit{person2.ID, place2.ID} 390 | 391 | // Insert it all 392 | noErr(t, txn.Insert("people", person1)) 393 | noErr(t, txn.Insert("people", person2)) 394 | noErr(t, txn.Insert("people", person3)) 395 | noErr(t, txn.Insert("places", place1)) 396 | noErr(t, txn.Insert("places", place2)) 397 | noErr(t, txn.Insert("places", place3)) 398 | noErr(t, txn.Insert("visits", visit1)) 399 | noErr(t, txn.Insert("visits", visit2)) 400 | 401 | // Commit 402 | txn.Commit() 403 | } 404 | 405 | func expectErr(t *testing.T, err error) { 406 | t.Helper() 407 | if err == nil { 408 | t.Fatal("expected error") 409 | } 410 | } 411 | 412 | func noErr(t 
*testing.T, err error) { 413 | t.Helper() 414 | if err != nil { 415 | t.Fatalf("err: %v", err) 416 | } 417 | } 418 | 419 | type TestPerson struct { 420 | ID string 421 | First string 422 | Last string 423 | Age uint8 424 | NegativeAge int8 425 | Sibling *TestPerson 426 | } 427 | 428 | type TestPlace struct { 429 | ID string 430 | Name string 431 | Tags []string 432 | Meta map[string]string 433 | } 434 | 435 | type TestVisit struct { 436 | Person string 437 | Place string 438 | } 439 | 440 | func testComplexSchema() *DBSchema { 441 | return &DBSchema{ 442 | Tables: map[string]*TableSchema{ 443 | "people": &TableSchema{ 444 | Name: "people", 445 | Indexes: map[string]*IndexSchema{ 446 | "id": &IndexSchema{ 447 | Name: "id", 448 | Unique: true, 449 | Indexer: &UUIDFieldIndex{Field: "ID"}, 450 | }, 451 | "name": &IndexSchema{ 452 | Name: "name", 453 | Unique: true, 454 | Indexer: &CompoundIndex{ 455 | Indexes: []Indexer{ 456 | &StringFieldIndex{Field: "First"}, 457 | &StringFieldIndex{Field: "Last"}, 458 | }, 459 | }, 460 | }, 461 | "age": &IndexSchema{ 462 | Name: "age", 463 | Unique: false, 464 | Indexer: &UintFieldIndex{Field: "Age"}, 465 | }, 466 | "negative_age": &IndexSchema{ 467 | Name: "negative_age", 468 | Unique: false, 469 | Indexer: &IntFieldIndex{Field: "NegativeAge"}, 470 | }, 471 | "sibling": &IndexSchema{ 472 | Name: "sibling", 473 | Unique: false, 474 | Indexer: &FieldSetIndex{Field: "Sibling"}, 475 | }, 476 | }, 477 | }, 478 | "places": &TableSchema{ 479 | Name: "places", 480 | Indexes: map[string]*IndexSchema{ 481 | "id": &IndexSchema{ 482 | Name: "id", 483 | Unique: true, 484 | Indexer: &UUIDFieldIndex{Field: "ID"}, 485 | }, 486 | "name": &IndexSchema{ 487 | Name: "name", 488 | Unique: true, 489 | Indexer: &StringFieldIndex{Field: "Name"}, 490 | }, 491 | "name_tags": &IndexSchema{ 492 | Name: "name_tags", 493 | Unique: true, 494 | AllowMissing: true, 495 | Indexer: &CompoundMultiIndex{ 496 | AllowMissing: true, 497 | Indexes: []Indexer{ 498 | 
&StringFieldIndex{Field: "Name"}, 499 | &StringSliceFieldIndex{Field: "Tags"}, 500 | }, 501 | }, 502 | }, 503 | "name_tags_name_meta": &IndexSchema{ 504 | Name: "name_tags_name_meta", 505 | Unique: true, 506 | AllowMissing: true, 507 | Indexer: &CompoundMultiIndex{ 508 | AllowMissing: true, 509 | Indexes: []Indexer{ 510 | &StringFieldIndex{Field: "Name"}, 511 | &StringSliceFieldIndex{Field: "Tags"}, 512 | &StringFieldIndex{Field: "Name"}, 513 | &StringMapFieldIndex{Field: "Meta"}, 514 | }, 515 | }, 516 | }, 517 | }, 518 | }, 519 | "visits": &TableSchema{ 520 | Name: "visits", 521 | Indexes: map[string]*IndexSchema{ 522 | "id": &IndexSchema{ 523 | Name: "id", 524 | Unique: true, 525 | Indexer: &CompoundIndex{ 526 | Indexes: []Indexer{ 527 | &UUIDFieldIndex{Field: "Person"}, 528 | &UUIDFieldIndex{Field: "Place"}, 529 | }, 530 | }, 531 | }, 532 | }, 533 | }, 534 | }, 535 | } 536 | } 537 | 538 | func testComplexDB(t *testing.T) *MemDB { 539 | db, err := NewMemDB(testComplexSchema()) 540 | if err != nil { 541 | t.Fatalf("err: %v", err) 542 | } 543 | return db 544 | } 545 | 546 | func testPerson() *TestPerson { 547 | _, uuid := generateUUID() 548 | obj := &TestPerson{ 549 | ID: uuid, 550 | First: "Armon", 551 | Last: "Dadgar", 552 | Age: 26, 553 | NegativeAge: -26, 554 | } 555 | return obj 556 | } 557 | 558 | func testPlace() *TestPlace { 559 | _, uuid := generateUUID() 560 | obj := &TestPlace{ 561 | ID: uuid, 562 | Name: "HashiCorp", 563 | } 564 | return obj 565 | } 566 | -------------------------------------------------------------------------------- /isolation_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | func TestMemDB_Isolation(t *testing.T) { 11 | 12 | id1 := "object-one" 13 | id2 := "object-two" 14 | id3 := "object-three" 15 | 16 | mustNoError := func(t *testing.T, err error) { 17 | if err != nil { 18 | t.Fatalf("unexpected test error: %v", err) 19 | } 20 | } 21 | 22 | setup := func(t *testing.T) *MemDB { 23 | t.Helper() 24 | 25 | db, err := NewMemDB(testValidSchema()) 26 | if err != nil { 27 | t.Fatalf("err: %v", err) 28 | } 29 | 30 | // Add two objects (with a gap between their IDs) 31 | obj1a := testObj() 32 | obj1a.ID = id1 33 | txn := db.Txn(true) 34 | mustNoError(t, txn.Insert("main", obj1a)) 35 | 36 | obj3 := testObj() 37 | obj3.ID = id3 38 | mustNoError(t, txn.Insert("main", obj3)) 39 | txn.Commit() 40 | return db 41 | } 42 | 43 | t.Run("snapshot dirty read", func(t *testing.T) { 44 | db := setup(t) 45 | db2 := db.Snapshot() 46 | 47 | // Update an object 48 | obj1b := testObj() 49 | obj1b.ID = id1 50 | txn1 := db.Txn(true) 51 | obj1b.Baz = "nope" 52 | mustNoError(t, txn1.Insert("main", obj1b)) 53 | 54 | // Insert an object 55 | obj2 := testObj() 56 | obj2.ID = id2 57 | mustNoError(t, txn1.Insert("main", obj2)) 58 | 59 | txn2 := db2.Txn(false) 60 | out, err := txn2.First("main", "id", id1) 61 | mustNoError(t, err) 62 | if out == nil { 63 | t.Fatalf("should exist") 64 | } 65 | if out.(*TestObject).Baz == "nope" { 66 | t.Fatalf("read from snapshot should not observe uncommitted update (dirty read)") 67 | } 68 | 69 | out, err = txn2.First("main", "id", id2) 70 | mustNoError(t, err) 71 | if out != nil { 72 | t.Fatalf("read from snapshot should not observe uncommitted insert (dirty read)") 73 | } 74 | 75 | // New snapshot should not observe uncommitted writes 76 | db3 := db.Snapshot() 77 | txn3 := db3.Txn(false) 78 | out, err = txn3.First("main", "id", id1) 79 | mustNoError(t, err) 80 | if out == nil { 81 | t.Fatalf("should exist") 82 | } 83 | if 
out.(*TestObject).Baz == "nope" { 84 | t.Fatalf("read from new snapshot should not observe uncommitted writes") 85 | } 86 | }) 87 | 88 | t.Run("transaction dirty read", func(t *testing.T) { 89 | db := setup(t) 90 | 91 | // Update an object 92 | obj1b := testObj() 93 | obj1b.ID = id1 94 | txn1 := db.Txn(true) 95 | obj1b.Baz = "nope" 96 | mustNoError(t, txn1.Insert("main", obj1b)) 97 | 98 | // Insert an object 99 | obj2 := testObj() 100 | obj2.ID = id2 101 | mustNoError(t, txn1.Insert("main", obj2)) 102 | 103 | txn2 := db.Txn(false) 104 | out, err := txn2.First("main", "id", id1) 105 | mustNoError(t, err) 106 | if out == nil { 107 | t.Fatalf("should exist") 108 | } 109 | if out.(*TestObject).Baz == "nope" { 110 | t.Fatalf("read from transaction should not observe uncommitted update (dirty read)") 111 | } 112 | 113 | out, err = txn2.First("main", "id", id2) 114 | mustNoError(t, err) 115 | if out != nil { 116 | t.Fatalf("read from transaction should not observe uncommitted insert (dirty read)") 117 | } 118 | }) 119 | 120 | t.Run("snapshot non-repeatable read", func(t *testing.T) { 121 | db := setup(t) 122 | db2 := db.Snapshot() 123 | 124 | // Update an object 125 | obj1b := testObj() 126 | obj1b.ID = id1 127 | txn1 := db.Txn(true) 128 | obj1b.Baz = "nope" 129 | mustNoError(t, txn1.Insert("main", obj1b)) 130 | 131 | // Insert an object 132 | obj2 := testObj() 133 | obj2.ID = id3 134 | mustNoError(t, txn1.Insert("main", obj2)) 135 | 136 | // Commit 137 | txn1.Commit() 138 | 139 | txn2 := db2.Txn(false) 140 | out, err := txn2.First("main", "id", id1) 141 | mustNoError(t, err) 142 | if out == nil { 143 | t.Fatalf("should exist") 144 | } 145 | if out.(*TestObject).Baz == "nope" { 146 | t.Fatalf("read from snapshot should not observe committed write from another transaction (non-repeatable read)") 147 | } 148 | 149 | out, err = txn2.First("main", "id", id2) 150 | mustNoError(t, err) 151 | if out != nil { 152 | t.Fatalf("read from snapshot should not observe committed write 
from another transaction (non-repeatable read)") 153 | } 154 | 155 | }) 156 | 157 | t.Run("transaction non-repeatable read", func(t *testing.T) { 158 | db := setup(t) 159 | 160 | // Update an object 161 | obj1b := testObj() 162 | obj1b.ID = id1 163 | txn1 := db.Txn(true) 164 | obj1b.Baz = "nope" 165 | mustNoError(t, txn1.Insert("main", obj1b)) 166 | 167 | // Insert an object 168 | obj2 := testObj() 169 | obj2.ID = id3 170 | mustNoError(t, txn1.Insert("main", obj2)) 171 | 172 | txn2 := db.Txn(false) 173 | 174 | // Commit 175 | txn1.Commit() 176 | 177 | out, err := txn2.First("main", "id", id1) 178 | mustNoError(t, err) 179 | if out == nil { 180 | t.Fatalf("should exist") 181 | } 182 | if out.(*TestObject).Baz == "nope" { 183 | t.Fatalf("read from transaction should not observe committed write from another transaction (non-repeatable read)") 184 | } 185 | 186 | out, err = txn2.First("main", "id", id2) 187 | mustNoError(t, err) 188 | if out != nil { 189 | t.Fatalf("read from transaction should not observe committed write from another transaction (non-repeatable read)") 190 | } 191 | 192 | }) 193 | 194 | t.Run("snapshot phantom read", func(t *testing.T) { 195 | db := setup(t) 196 | db2 := db.Snapshot() 197 | 198 | txn2 := db2.Txn(false) 199 | iter, err := txn2.Get("main", "id_prefix", "object") 200 | mustNoError(t, err) 201 | out := iter.Next() 202 | if out == nil || out.(*TestObject).ID != id1 { 203 | t.Fatal("missing expected object 'object-one'") 204 | } 205 | 206 | // Insert an object and commit 207 | txn1 := db.Txn(true) 208 | obj2 := testObj() 209 | obj2.ID = id2 210 | mustNoError(t, txn1.Insert("main", obj2)) 211 | txn1.Commit() 212 | 213 | out = iter.Next() 214 | if out == nil { 215 | t.Fatal("expected 2 objects") 216 | } 217 | if out.(*TestObject).ID == id2 { 218 | t.Fatalf("read from snapshot should not observe new objects in set (phantom read)") 219 | } 220 | 221 | out = iter.Next() 222 | if out != nil { 223 | t.Fatal("expected only 2 objects: read from 
snapshot should not observe new objects in set (phantom read)") 224 | } 225 | 226 | // Remove an object using an outdated pointer 227 | txn1 = db.Txn(true) 228 | obj1, err := txn1.First("main", "id", id1) 229 | mustNoError(t, err) 230 | mustNoError(t, txn1.Delete("main", obj1)) 231 | txn1.Commit() 232 | 233 | iter, err = txn2.Get("main", "id_prefix", "object") 234 | mustNoError(t, err) 235 | 236 | out = iter.Next() 237 | if out == nil || out.(*TestObject).ID != id1 { 238 | t.Fatal("missing expected object 'object-one': read from snapshot should not observe deletes (phantom read)") 239 | } 240 | out = iter.Next() 241 | if out == nil || out.(*TestObject).ID != id3 { 242 | t.Fatal("missing expected object 'object-three': read from snapshot should not observe deletes (phantom read)") 243 | } 244 | 245 | }) 246 | 247 | t.Run("transaction phantom read", func(t *testing.T) { 248 | db := setup(t) 249 | 250 | txn2 := db.Txn(false) 251 | iter, err := txn2.Get("main", "id_prefix", "object") 252 | mustNoError(t, err) 253 | out := iter.Next() 254 | if out == nil || out.(*TestObject).ID != id1 { 255 | t.Fatal("missing expected object 'object-one'") 256 | } 257 | 258 | // Insert an object and commit 259 | txn1 := db.Txn(true) 260 | obj2 := testObj() 261 | obj2.ID = id2 262 | mustNoError(t, txn1.Insert("main", obj2)) 263 | txn1.Commit() 264 | 265 | out = iter.Next() 266 | if out == nil { 267 | t.Fatal("expected 2 objects") 268 | } 269 | if out.(*TestObject).ID == id2 { 270 | t.Fatalf("read from transaction should not observe new objects in set (phantom read)") 271 | } 272 | 273 | out = iter.Next() 274 | if out != nil { 275 | t.Fatal("expected only 2 objects: read from transaction should not observe new objects in set (phantom read)") 276 | } 277 | 278 | // Remove an object using an outdated pointer 279 | txn1 = db.Txn(true) 280 | obj1, err := txn1.First("main", "id", id1) 281 | mustNoError(t, err) 282 | mustNoError(t, txn1.Delete("main", obj1)) 283 | txn1.Commit() 284 | 285 | 
iter, err = txn2.Get("main", "id_prefix", "object") 286 | if err != nil { 287 | t.Fatalf("err: %v", err) 288 | } 289 | 290 | out = iter.Next() 291 | if out == nil || out.(*TestObject).ID != id1 { 292 | t.Fatal("missing expected object 'object-one': read from transaction should not observe deletes (phantom read)") 293 | } 294 | out = iter.Next() 295 | if out == nil || out.(*TestObject).ID != id3 { 296 | t.Fatal("missing expected object 'object-three': read from transaction should not observe deletes (phantom read)") 297 | } 298 | 299 | }) 300 | 301 | t.Run("snapshot commits are unobservable", func(t *testing.T) { 302 | db := setup(t) 303 | db2 := db.Snapshot() 304 | 305 | txn2 := db2.Txn(true) 306 | obj1 := testObj() 307 | obj1.ID = id1 308 | obj1.Baz = "also" 309 | mustNoError(t, txn2.Insert("main", obj1)) 310 | txn2.Commit() 311 | 312 | txn1 := db.Txn(false) 313 | out, err := txn1.First("main", "id", id1) 314 | mustNoError(t, err) 315 | if out == nil { 316 | t.Fatalf("should exist") 317 | } 318 | if out.(*TestObject).Baz == "also" { 319 | t.Fatalf("commit from snapshot should never be observed") 320 | } 321 | }) 322 | } 323 | -------------------------------------------------------------------------------- /memdb.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | // Package memdb provides an in-memory database that supports transactions 5 | // and MVCC. 6 | package memdb 7 | 8 | import ( 9 | "sync" 10 | "sync/atomic" 11 | "unsafe" 12 | 13 | "github.com/hashicorp/go-immutable-radix" 14 | ) 15 | 16 | // MemDB is an in-memory database providing Atomicity, Consistency, and 17 | // Isolation from ACID. MemDB doesn't provide Durability since it is an 18 | // in-memory database. 19 | // 20 | // MemDB provides a table abstraction to store objects (rows) with multiple 21 | // indexes based on inserted values. 
The database makes use of immutable radix 22 | // trees to provide transactions and MVCC. 23 | // 24 | // Objects inserted into MemDB are not copied. It is **extremely important** 25 | // that objects are not modified in-place after they are inserted since they 26 | // are stored directly in MemDB. It remains unsafe to modify inserted objects 27 | // even after they've been deleted from MemDB since there may still be older 28 | // snapshots of the DB being read from other goroutines. 29 | type MemDB struct { 30 | schema *DBSchema 31 | root unsafe.Pointer // *iradix.Tree underneath 32 | primary bool 33 | 34 | // There can only be a single writer at once 35 | writer sync.Mutex 36 | } 37 | 38 | // NewMemDB creates a new MemDB with the given schema. 39 | func NewMemDB(schema *DBSchema) (*MemDB, error) { 40 | // Validate the schema 41 | if err := schema.Validate(); err != nil { 42 | return nil, err 43 | } 44 | 45 | // Create the MemDB 46 | db := &MemDB{ 47 | schema: schema, 48 | root: unsafe.Pointer(iradix.New()), 49 | primary: true, 50 | } 51 | if err := db.initialize(); err != nil { 52 | return nil, err 53 | } 54 | 55 | return db, nil 56 | } 57 | 58 | // DBSchema returns schema in use for introspection. 59 | // 60 | // The method is intended for *read-only* debugging use cases, 61 | // returned schema should *never be modified in-place*. 62 | func (db *MemDB) DBSchema() *DBSchema { 63 | return db.schema 64 | } 65 | 66 | // getRoot is used to do an atomic load of the root pointer 67 | func (db *MemDB) getRoot() *iradix.Tree { 68 | root := (*iradix.Tree)(atomic.LoadPointer(&db.root)) 69 | return root 70 | } 71 | 72 | // Txn is used to start a new transaction in either read or write mode. 73 | // There can only be a single concurrent writer, but any number of readers. 
74 | func (db *MemDB) Txn(write bool) *Txn { 75 | if write { 76 | db.writer.Lock() 77 | } 78 | txn := &Txn{ 79 | db: db, 80 | write: write, 81 | rootTxn: db.getRoot().Txn(), 82 | } 83 | return txn 84 | } 85 | 86 | // Snapshot is used to capture a point-in-time snapshot of the database that 87 | // will not be affected by any write operations to the existing DB. 88 | // 89 | // If MemDB is storing reference-based values (pointers, maps, slices, etc.), 90 | // the Snapshot will not deep copy those values. Therefore, it is still unsafe 91 | // to modify any inserted values in either DB. 92 | func (db *MemDB) Snapshot() *MemDB { 93 | clone := &MemDB{ 94 | schema: db.schema, 95 | root: unsafe.Pointer(db.getRoot()), 96 | primary: false, 97 | } 98 | return clone 99 | } 100 | 101 | // initialize is used to setup the DB for use after creation. This should 102 | // be called only once after allocating a MemDB. 103 | func (db *MemDB) initialize() error { 104 | root := db.getRoot() 105 | for tName, tableSchema := range db.schema.Tables { 106 | for iName := range tableSchema.Indexes { 107 | index := iradix.New() 108 | path := indexPath(tName, iName) 109 | root, _, _ = root.Insert(path, index) 110 | } 111 | } 112 | db.root = unsafe.Pointer(root) 113 | return nil 114 | } 115 | 116 | // indexPath returns the path from the root to the given table index 117 | func indexPath(table, index string) []byte { 118 | return []byte(table + "." + index) 119 | } 120 | -------------------------------------------------------------------------------- /memdb_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import ( 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestMemDB_SingleWriter_MultiReader(t *testing.T) { 12 | db, err := NewMemDB(testValidSchema()) 13 | if err != nil { 14 | t.Fatalf("err: %v", err) 15 | } 16 | 17 | tx1 := db.Txn(true) 18 | tx2 := db.Txn(false) // Should not block! 19 | tx3 := db.Txn(false) // Should not block! 20 | tx4 := db.Txn(false) // Should not block! 21 | 22 | doneCh := make(chan struct{}) 23 | go func() { 24 | defer close(doneCh) 25 | db.Txn(true) 26 | }() 27 | 28 | select { 29 | case <-doneCh: 30 | t.Fatalf("should not allow another writer") 31 | case <-time.After(10 * time.Millisecond): 32 | } 33 | 34 | tx1.Abort() 35 | tx2.Abort() 36 | tx3.Abort() 37 | tx4.Abort() 38 | 39 | select { 40 | case <-doneCh: 41 | case <-time.After(10 * time.Millisecond): 42 | t.Fatalf("should allow another writer") 43 | } 44 | } 45 | 46 | func TestMemDB_Snapshot(t *testing.T) { 47 | db, err := NewMemDB(testValidSchema()) 48 | if err != nil { 49 | t.Fatalf("err: %v", err) 50 | } 51 | 52 | // Add an object 53 | obj := testObj() 54 | txn := db.Txn(true) 55 | if err := txn.Insert("main", obj); err != nil { 56 | t.Fatalf("err: %v", err) 57 | } 58 | txn.Commit() 59 | 60 | // Clone the db 61 | db2 := db.Snapshot() 62 | 63 | // Remove the object 64 | txn = db.Txn(true) 65 | if err := txn.Delete("main", obj); err != nil { 66 | t.Fatalf("err: %v", err) 67 | } 68 | txn.Commit() 69 | 70 | // Object should exist in second snapshot but not first 71 | txn = db.Txn(false) 72 | out, err := txn.First("main", "id", obj.ID) 73 | if err != nil { 74 | t.Fatalf("err: %v", err) 75 | } 76 | if out != nil { 77 | t.Fatalf("should not exist %#v", out) 78 | } 79 | 80 | txn = db2.Txn(true) 81 | out, err = txn.First("main", "id", obj.ID) 82 | if err != nil { 83 | t.Fatalf("err: %v", err) 84 | } 85 | if out == nil { 86 | t.Fatalf("should exist") 87 | } 88 | } 89 | 
-------------------------------------------------------------------------------- /schema.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import "fmt" 7 | 8 | // DBSchema is the schema to use for the full database with a MemDB instance. 9 | // 10 | // MemDB will require a valid schema. Schema validation can be tested using 11 | // the Validate function. Calling this function is recommended in unit tests. 12 | type DBSchema struct { 13 | // Tables is the set of tables within this database. The key is the 14 | // table name and must match the Name in TableSchema. 15 | Tables map[string]*TableSchema 16 | } 17 | 18 | // Validate validates the schema. 19 | func (s *DBSchema) Validate() error { 20 | if s == nil { 21 | return fmt.Errorf("schema is nil") 22 | } 23 | 24 | if len(s.Tables) == 0 { 25 | return fmt.Errorf("schema has no tables defined") 26 | } 27 | 28 | for name, table := range s.Tables { 29 | if name != table.Name { 30 | return fmt.Errorf("table name mis-match for '%s'", name) 31 | } 32 | 33 | if err := table.Validate(); err != nil { 34 | return fmt.Errorf("table %q: %s", name, err) 35 | } 36 | } 37 | 38 | return nil 39 | } 40 | 41 | // TableSchema is the schema for a single table. 42 | type TableSchema struct { 43 | // Name of the table. This must match the key in the Tables map in DBSchema. 44 | Name string 45 | 46 | // Indexes is the set of indexes for querying this table. The key 47 | // is a unique name for the index and must match the Name in the 48 | // IndexSchema. 
49 | Indexes map[string]*IndexSchema 50 | } 51 | 52 | // Validate is used to validate the table schema 53 | func (s *TableSchema) Validate() error { 54 | if s.Name == "" { 55 | return fmt.Errorf("missing table name") 56 | } 57 | 58 | if len(s.Indexes) == 0 { 59 | return fmt.Errorf("missing table indexes for '%s'", s.Name) 60 | } 61 | 62 | if _, ok := s.Indexes["id"]; !ok { 63 | return fmt.Errorf("must have id index") 64 | } 65 | 66 | if !s.Indexes["id"].Unique { 67 | return fmt.Errorf("id index must be unique") 68 | } 69 | 70 | if _, ok := s.Indexes["id"].Indexer.(SingleIndexer); !ok { 71 | return fmt.Errorf("id index must be a SingleIndexer") 72 | } 73 | 74 | for name, index := range s.Indexes { 75 | if name != index.Name { 76 | return fmt.Errorf("index name mis-match for '%s'", name) 77 | } 78 | 79 | if err := index.Validate(); err != nil { 80 | return fmt.Errorf("index %q: %s", name, err) 81 | } 82 | } 83 | 84 | return nil 85 | } 86 | 87 | // IndexSchema is the schema for an index. An index defines how a table is 88 | // queried. 89 | type IndexSchema struct { 90 | // Name of the index. This must be unique among a tables set of indexes. 91 | // This must match the key in the map of Indexes for a TableSchema. 92 | Name string 93 | 94 | // AllowMissing if true ignores this index if it doesn't produce a 95 | // value. For example, an index that extracts a field that doesn't 96 | // exist from a structure. 
97 | AllowMissing bool 98 | 99 | Unique bool 100 | Indexer Indexer 101 | } 102 | 103 | func (s *IndexSchema) Validate() error { 104 | if s.Name == "" { 105 | return fmt.Errorf("missing index name") 106 | } 107 | if s.Indexer == nil { 108 | return fmt.Errorf("missing index function for '%s'", s.Name) 109 | } 110 | switch s.Indexer.(type) { 111 | case SingleIndexer: 112 | case MultiIndexer: 113 | default: 114 | return fmt.Errorf("indexer for '%s' must be a SingleIndexer or MultiIndexer", s.Name) 115 | } 116 | return nil 117 | } 118 | -------------------------------------------------------------------------------- /schema_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import "testing" 7 | 8 | func testValidSchema() *DBSchema { 9 | return &DBSchema{ 10 | Tables: map[string]*TableSchema{ 11 | "main": &TableSchema{ 12 | Name: "main", 13 | Indexes: map[string]*IndexSchema{ 14 | "id": &IndexSchema{ 15 | Name: "id", 16 | Unique: true, 17 | Indexer: &StringFieldIndex{Field: "ID"}, 18 | }, 19 | "foo": &IndexSchema{ 20 | Name: "foo", 21 | Indexer: &StringFieldIndex{Field: "Foo"}, 22 | }, 23 | "qux": &IndexSchema{ 24 | Name: "qux", 25 | Indexer: &StringSliceFieldIndex{Field: "Qux"}, 26 | }, 27 | }, 28 | }, 29 | }, 30 | } 31 | } 32 | 33 | func TestDBSchema_Validate(t *testing.T) { 34 | s := &DBSchema{} 35 | err := s.Validate() 36 | if err == nil { 37 | t.Fatalf("should not validate, empty") 38 | } 39 | 40 | s.Tables = map[string]*TableSchema{ 41 | "foo": &TableSchema{Name: "foo"}, 42 | } 43 | err = s.Validate() 44 | if err == nil { 45 | t.Fatalf("should not validate, no indexes") 46 | } 47 | 48 | valid := testValidSchema() 49 | err = valid.Validate() 50 | if err != nil { 51 | t.Fatalf("should validate: %v", err) 52 | } 53 | } 54 | 55 | func TestTableSchema_Validate(t *testing.T) { 56 | s := &TableSchema{} 57 | err := s.Validate() 
58 | if err == nil { 59 | t.Fatalf("should not validate, empty") 60 | } 61 | 62 | s.Indexes = map[string]*IndexSchema{ 63 | "foo": &IndexSchema{Name: "foo"}, 64 | } 65 | err = s.Validate() 66 | if err == nil { 67 | t.Fatalf("should not validate, no indexes") 68 | } 69 | 70 | valid := &TableSchema{ 71 | Name: "main", 72 | Indexes: map[string]*IndexSchema{ 73 | "id": &IndexSchema{ 74 | Name: "id", 75 | Unique: true, 76 | Indexer: &StringFieldIndex{Field: "ID", Lowercase: true}, 77 | }, 78 | }, 79 | } 80 | err = valid.Validate() 81 | if err != nil { 82 | t.Fatalf("should validate: %v", err) 83 | } 84 | } 85 | 86 | func TestIndexSchema_Validate(t *testing.T) { 87 | s := &IndexSchema{} 88 | err := s.Validate() 89 | if err == nil { 90 | t.Fatalf("should not validate, empty") 91 | } 92 | 93 | s.Name = "foo" 94 | err = s.Validate() 95 | if err == nil { 96 | t.Fatalf("should not validate, no indexer") 97 | } 98 | 99 | s.Indexer = &StringFieldIndex{Field: "Foo", Lowercase: false} 100 | err = s.Validate() 101 | if err != nil { 102 | t.Fatalf("should validate: %v", err) 103 | } 104 | 105 | s.Indexer = &StringSliceFieldIndex{Field: "Qux", Lowercase: false} 106 | err = s.Validate() 107 | if err != nil { 108 | t.Fatalf("should validate: %v", err) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /txn.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "strings" 10 | "sync/atomic" 11 | "unsafe" 12 | 13 | iradix "github.com/hashicorp/go-immutable-radix" 14 | ) 15 | 16 | const ( 17 | id = "id" 18 | ) 19 | 20 | var ( 21 | // ErrNotFound is returned when the requested item is not found 22 | ErrNotFound = fmt.Errorf("not found") 23 | ) 24 | 25 | // tableIndex is a tuple of (Table, Index) used for lookups 26 | type tableIndex struct { 27 | Table string 28 | Index string 29 | } 30 | 31 | // Txn is a transaction against a MemDB. 32 | // This can be a read or write transaction. 33 | type Txn struct { 34 | db *MemDB 35 | write bool 36 | rootTxn *iradix.Txn 37 | after []func() 38 | 39 | // changes is used to track the changes performed during the transaction. If 40 | // it is nil at transaction start then changes are not tracked. 41 | changes Changes 42 | 43 | modified map[tableIndex]*iradix.Txn 44 | } 45 | 46 | // TrackChanges enables change tracking for the transaction. If called at any 47 | // point before commit, subsequent mutations will be recorded and can be 48 | // retrieved using ChangeSet. Once this has been called on a transaction it 49 | // can't be unset. As with other Txn methods it's not safe to call this from a 50 | // different goroutine than the one making mutations or committing the 51 | // transaction. 52 | func (txn *Txn) TrackChanges() { 53 | if txn.changes == nil { 54 | txn.changes = make(Changes, 0, 1) 55 | } 56 | } 57 | 58 | // readableIndex returns a transaction usable for reading the given index in a 59 | // table. If the transaction is a write transaction with modifications, a clone of the 60 | // modified index will be returned. 
61 | func (txn *Txn) readableIndex(table, index string) *iradix.Txn { 62 | // Look for existing transaction 63 | if txn.write && txn.modified != nil { 64 | key := tableIndex{table, index} 65 | exist, ok := txn.modified[key] 66 | if ok { 67 | return exist.Clone() 68 | } 69 | } 70 | 71 | // Create a read transaction 72 | path := indexPath(table, index) 73 | raw, _ := txn.rootTxn.Get(path) 74 | indexTxn := raw.(*iradix.Tree).Txn() 75 | return indexTxn 76 | } 77 | 78 | // writableIndex returns a transaction usable for modifying the 79 | // given index in a table. 80 | func (txn *Txn) writableIndex(table, index string) *iradix.Txn { 81 | if txn.modified == nil { 82 | txn.modified = make(map[tableIndex]*iradix.Txn) 83 | } 84 | 85 | // Look for existing transaction 86 | key := tableIndex{table, index} 87 | exist, ok := txn.modified[key] 88 | if ok { 89 | return exist 90 | } 91 | 92 | // Start a new transaction 93 | path := indexPath(table, index) 94 | raw, _ := txn.rootTxn.Get(path) 95 | indexTxn := raw.(*iradix.Tree).Txn() 96 | 97 | // If we are the primary DB, enable mutation tracking. Snapshots should 98 | // not notify, otherwise we will trigger watches on the primary DB when 99 | // the writes will not be visible. 100 | indexTxn.TrackMutate(txn.db.primary) 101 | 102 | // Keep this open for the duration of the txn 103 | txn.modified[key] = indexTxn 104 | return indexTxn 105 | } 106 | 107 | // Abort is used to cancel this transaction. 108 | // This is a noop for read transactions, 109 | // already aborted or commited transactions. 
110 | func (txn *Txn) Abort() { 111 | // Noop for a read transaction 112 | if !txn.write { 113 | return 114 | } 115 | 116 | // Check if already aborted or committed 117 | if txn.rootTxn == nil { 118 | return 119 | } 120 | 121 | // Clear the txn 122 | txn.rootTxn = nil 123 | txn.modified = nil 124 | txn.changes = nil 125 | 126 | // Release the writer lock since this is invalid 127 | txn.db.writer.Unlock() 128 | } 129 | 130 | // Commit is used to finalize this transaction. 131 | // This is a noop for read transactions, 132 | // already aborted or committed transactions. 133 | func (txn *Txn) Commit() { 134 | // Noop for a read transaction 135 | if !txn.write { 136 | return 137 | } 138 | 139 | // Check if already aborted or committed 140 | if txn.rootTxn == nil { 141 | return 142 | } 143 | 144 | // Commit each sub-transaction scoped to (table, index) 145 | for key, subTxn := range txn.modified { 146 | path := indexPath(key.Table, key.Index) 147 | final := subTxn.CommitOnly() 148 | txn.rootTxn.Insert(path, final) 149 | } 150 | 151 | // Update the root of the DB 152 | newRoot := txn.rootTxn.CommitOnly() 153 | atomic.StorePointer(&txn.db.root, unsafe.Pointer(newRoot)) 154 | 155 | // Now issue all of the mutation updates (this is safe to call 156 | // even if mutation tracking isn't enabled); we do this after 157 | // the root pointer is swapped so that waking responders will 158 | // see the new state. 159 | for _, subTxn := range txn.modified { 160 | subTxn.Notify() 161 | } 162 | txn.rootTxn.Notify() 163 | 164 | // Clear the txn 165 | txn.rootTxn = nil 166 | txn.modified = nil 167 | 168 | // Release the writer lock since this is invalid 169 | txn.db.writer.Unlock() 170 | 171 | // Run the deferred functions, if any 172 | for i := len(txn.after); i > 0; i-- { 173 | fn := txn.after[i-1] 174 | fn() 175 | } 176 | } 177 | 178 | // Insert is used to add or update an object into the given table. 
179 | // 180 | // When updating an object, the obj provided should be a copy rather 181 | // than a value updated in-place. Modifying values in-place that are already 182 | // inserted into MemDB is not supported behavior. 183 | func (txn *Txn) Insert(table string, obj interface{}) error { 184 | if !txn.write { 185 | return fmt.Errorf("cannot insert in read-only transaction") 186 | } 187 | 188 | // Get the table schema 189 | tableSchema, ok := txn.db.schema.Tables[table] 190 | if !ok { 191 | return fmt.Errorf("invalid table '%s'", table) 192 | } 193 | 194 | // Get the primary ID of the object 195 | idSchema := tableSchema.Indexes[id] 196 | idIndexer := idSchema.Indexer.(SingleIndexer) 197 | ok, idVal, err := idIndexer.FromObject(obj) 198 | if err != nil { 199 | return fmt.Errorf("failed to build primary index: %v", err) 200 | } 201 | if !ok { 202 | return fmt.Errorf("object missing primary index") 203 | } 204 | 205 | // Lookup the object by ID first, to see if this is an update 206 | idTxn := txn.writableIndex(table, id) 207 | existing, update := idTxn.Get(idVal) 208 | 209 | // On an update, there is an existing object with the given 210 | // primary ID. We do the update by deleting the current object 211 | // and inserting the new object. 212 | for name, indexSchema := range tableSchema.Indexes { 213 | indexTxn := txn.writableIndex(table, name) 214 | 215 | // Determine the new index value 216 | var ( 217 | ok bool 218 | vals [][]byte 219 | err error 220 | ) 221 | switch indexer := indexSchema.Indexer.(type) { 222 | case SingleIndexer: 223 | var val []byte 224 | ok, val, err = indexer.FromObject(obj) 225 | vals = [][]byte{val} 226 | case MultiIndexer: 227 | ok, vals, err = indexer.FromObject(obj) 228 | } 229 | if err != nil { 230 | return fmt.Errorf("failed to build index '%s': %v", name, err) 231 | } 232 | 233 | // Handle non-unique index by computing a unique index. 234 | // This is done by appending the primary key which must 235 | // be unique anyways. 
236 | if ok && !indexSchema.Unique { 237 | for i := range vals { 238 | vals[i] = append(vals[i], idVal...) 239 | } 240 | } 241 | 242 | // Handle the update by deleting from the index first 243 | if update { 244 | var ( 245 | okExist bool 246 | valsExist [][]byte 247 | err error 248 | ) 249 | switch indexer := indexSchema.Indexer.(type) { 250 | case SingleIndexer: 251 | var valExist []byte 252 | okExist, valExist, err = indexer.FromObject(existing) 253 | valsExist = [][]byte{valExist} 254 | case MultiIndexer: 255 | okExist, valsExist, err = indexer.FromObject(existing) 256 | } 257 | if err != nil { 258 | return fmt.Errorf("failed to build index '%s': %v", name, err) 259 | } 260 | if okExist { 261 | for i, valExist := range valsExist { 262 | // Handle non-unique index by computing a unique index. 263 | // This is done by appending the primary key which must 264 | // be unique anyways. 265 | if !indexSchema.Unique { 266 | valExist = append(valExist, idVal...) 267 | } 268 | 269 | // If we are writing to the same index with the same value, 270 | // we can avoid the delete as the insert will overwrite the 271 | // value anyways. 272 | if i >= len(vals) || !bytes.Equal(valExist, vals[i]) { 273 | indexTxn.Delete(valExist) 274 | } 275 | } 276 | } 277 | } 278 | 279 | // If there is no index value, either this is an error or an expected 280 | // case and we can skip updating 281 | if !ok { 282 | if indexSchema.AllowMissing { 283 | continue 284 | } else { 285 | return fmt.Errorf("missing value for index '%s'", name) 286 | } 287 | } 288 | 289 | // Update the value of the index 290 | for _, val := range vals { 291 | indexTxn.Insert(val, obj) 292 | } 293 | } 294 | if txn.changes != nil { 295 | txn.changes = append(txn.changes, Change{ 296 | Table: table, 297 | Before: existing, // might be nil on a create 298 | After: obj, 299 | primaryKey: idVal, 300 | }) 301 | } 302 | return nil 303 | } 304 | 305 | // Delete is used to delete a single object from the given table. 
306 | // This object must already exist in the table. 307 | func (txn *Txn) Delete(table string, obj interface{}) error { 308 | if !txn.write { 309 | return fmt.Errorf("cannot delete in read-only transaction") 310 | } 311 | 312 | // Get the table schema 313 | tableSchema, ok := txn.db.schema.Tables[table] 314 | if !ok { 315 | return fmt.Errorf("invalid table '%s'", table) 316 | } 317 | 318 | // Get the primary ID of the object 319 | idSchema := tableSchema.Indexes[id] 320 | idIndexer := idSchema.Indexer.(SingleIndexer) 321 | ok, idVal, err := idIndexer.FromObject(obj) 322 | if err != nil { 323 | return fmt.Errorf("failed to build primary index: %v", err) 324 | } 325 | if !ok { 326 | return fmt.Errorf("object missing primary index") 327 | } 328 | 329 | // Lookup the object by ID first, check if we should continue 330 | idTxn := txn.writableIndex(table, id) 331 | existing, ok := idTxn.Get(idVal) 332 | if !ok { 333 | return ErrNotFound 334 | } 335 | 336 | // Remove the object from all the indexes 337 | for name, indexSchema := range tableSchema.Indexes { 338 | indexTxn := txn.writableIndex(table, name) 339 | 340 | // Handle the update by deleting from the index first 341 | var ( 342 | ok bool 343 | vals [][]byte 344 | err error 345 | ) 346 | switch indexer := indexSchema.Indexer.(type) { 347 | case SingleIndexer: 348 | var val []byte 349 | ok, val, err = indexer.FromObject(existing) 350 | vals = [][]byte{val} 351 | case MultiIndexer: 352 | ok, vals, err = indexer.FromObject(existing) 353 | } 354 | if err != nil { 355 | return fmt.Errorf("failed to build index '%s': %v", name, err) 356 | } 357 | if ok { 358 | // Handle non-unique index by computing a unique index. 359 | // This is done by appending the primary key which must 360 | // be unique anyways. 361 | for _, val := range vals { 362 | if !indexSchema.Unique { 363 | val = append(val, idVal...) 
364 | } 365 | indexTxn.Delete(val) 366 | } 367 | } 368 | } 369 | if txn.changes != nil { 370 | txn.changes = append(txn.changes, Change{ 371 | Table: table, 372 | Before: existing, 373 | After: nil, // Now nil indicates deletion 374 | primaryKey: idVal, 375 | }) 376 | } 377 | return nil 378 | } 379 | 380 | // DeletePrefix is used to delete an entire subtree based on a prefix. 381 | // The given index must be a prefix index, and will be used to perform a scan and enumerate the set of objects to delete. 382 | // These will be removed from all other indexes, and then a special prefix operation will delete the objects from the given index in an efficient subtree delete operation. 383 | // This is useful when you have a very large number of objects indexed by the given index, along with a much smaller number of entries in the other indexes for those objects. 384 | func (txn *Txn) DeletePrefix(table string, prefix_index string, prefix string) (bool, error) { 385 | if !txn.write { 386 | return false, fmt.Errorf("cannot delete in read-only transaction") 387 | } 388 | 389 | if !strings.HasSuffix(prefix_index, "_prefix") { 390 | return false, fmt.Errorf("Index name for DeletePrefix must be a prefix index, Got %v ", prefix_index) 391 | } 392 | 393 | deletePrefixIndex := strings.TrimSuffix(prefix_index, "_prefix") 394 | 395 | // Get an iterator over all of the keys with the given prefix. 
396 | entries, err := txn.Get(table, prefix_index, prefix) 397 | if err != nil { 398 | return false, fmt.Errorf("failed kvs lookup: %s", err) 399 | } 400 | // Get the table schema 401 | tableSchema, ok := txn.db.schema.Tables[table] 402 | if !ok { 403 | return false, fmt.Errorf("invalid table '%s'", table) 404 | } 405 | 406 | foundAny := false 407 | for entry := entries.Next(); entry != nil; entry = entries.Next() { 408 | if !foundAny { 409 | foundAny = true 410 | } 411 | // Get the primary ID of the object 412 | idSchema := tableSchema.Indexes[id] 413 | idIndexer := idSchema.Indexer.(SingleIndexer) 414 | ok, idVal, err := idIndexer.FromObject(entry) 415 | if err != nil { 416 | return false, fmt.Errorf("failed to build primary index: %v", err) 417 | } 418 | if !ok { 419 | return false, fmt.Errorf("object missing primary index") 420 | } 421 | if txn.changes != nil { 422 | // Record the deletion 423 | idTxn := txn.writableIndex(table, id) 424 | existing, ok := idTxn.Get(idVal) 425 | if ok { 426 | txn.changes = append(txn.changes, Change{ 427 | Table: table, 428 | Before: existing, 429 | After: nil, // Now nil indicates deletion 430 | primaryKey: idVal, 431 | }) 432 | } 433 | } 434 | // Remove the object from all the indexes except the given prefix index 435 | for name, indexSchema := range tableSchema.Indexes { 436 | if name == deletePrefixIndex { 437 | continue 438 | } 439 | indexTxn := txn.writableIndex(table, name) 440 | 441 | // Handle the update by deleting from the index first 442 | var ( 443 | ok bool 444 | vals [][]byte 445 | err error 446 | ) 447 | switch indexer := indexSchema.Indexer.(type) { 448 | case SingleIndexer: 449 | var val []byte 450 | ok, val, err = indexer.FromObject(entry) 451 | vals = [][]byte{val} 452 | case MultiIndexer: 453 | ok, vals, err = indexer.FromObject(entry) 454 | } 455 | if err != nil { 456 | return false, fmt.Errorf("failed to build index '%s': %v", name, err) 457 | } 458 | 459 | if ok { 460 | // Handle non-unique index by 
computing a unique index. 461 | // This is done by appending the primary key which must 462 | // be unique anyways. 463 | for _, val := range vals { 464 | if !indexSchema.Unique { 465 | val = append(val, idVal...) 466 | } 467 | indexTxn.Delete(val) 468 | } 469 | } 470 | } 471 | 472 | } 473 | if foundAny { 474 | indexTxn := txn.writableIndex(table, deletePrefixIndex) 475 | ok = indexTxn.DeletePrefix([]byte(prefix)) 476 | if !ok { 477 | panic(fmt.Errorf("prefix %v matched some entries but DeletePrefix did not delete any ", prefix)) 478 | } 479 | return true, nil 480 | } 481 | return false, nil 482 | } 483 | 484 | // DeleteAll is used to delete all the objects in a given table 485 | // matching the constraints on the index 486 | func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) { 487 | if !txn.write { 488 | return 0, fmt.Errorf("cannot delete in read-only transaction") 489 | } 490 | 491 | // Get all the objects 492 | iter, err := txn.Get(table, index, args...) 493 | if err != nil { 494 | return 0, err 495 | } 496 | 497 | // Put them into a slice so there are no safety concerns while actually 498 | // performing the deletes 499 | var objs []interface{} 500 | for { 501 | obj := iter.Next() 502 | if obj == nil { 503 | break 504 | } 505 | 506 | objs = append(objs, obj) 507 | } 508 | 509 | // Do the deletes 510 | num := 0 511 | for _, obj := range objs { 512 | if err := txn.Delete(table, obj); err != nil { 513 | return num, err 514 | } 515 | num++ 516 | } 517 | return num, nil 518 | } 519 | 520 | // FirstWatch is used to return the first matching object for 521 | // the given constraints on the index along with the watch channel. 522 | // 523 | // Note that all values read in the transaction form a consistent snapshot 524 | // from the time when the transaction was created. 525 | // 526 | // The watch channel is closed when a subsequent write transaction 527 | // has updated the result of the query. 
Since each read transaction 528 | // operates on an isolated snapshot, a new read transaction must be 529 | // started to observe the changes that have been made. 530 | // 531 | // If the value of index ends with "_prefix", FirstWatch will perform a prefix 532 | // match instead of full match on the index. The registered indexer must implement 533 | // PrefixIndexer, otherwise an error is returned. 534 | func (txn *Txn) FirstWatch(table, index string, args ...interface{}) (<-chan struct{}, interface{}, error) { 535 | // Get the index value 536 | indexSchema, val, err := txn.getIndexValue(table, index, args...) 537 | if err != nil { 538 | return nil, nil, err 539 | } 540 | 541 | // Get the index itself 542 | indexTxn := txn.readableIndex(table, indexSchema.Name) 543 | 544 | // Do an exact lookup 545 | if indexSchema.Unique && val != nil && indexSchema.Name == index { 546 | watch, obj, ok := indexTxn.GetWatch(val) 547 | if !ok { 548 | return watch, nil, nil 549 | } 550 | return watch, obj, nil 551 | } 552 | 553 | // Handle non-unique index by using an iterator and getting the first value 554 | iter := indexTxn.Root().Iterator() 555 | watch := iter.SeekPrefixWatch(val) 556 | _, value, _ := iter.Next() 557 | return watch, value, nil 558 | } 559 | 560 | // LastWatch is used to return the last matching object for 561 | // the given constraints on the index along with the watch channel. 562 | // 563 | // Note that all values read in the transaction form a consistent snapshot 564 | // from the time when the transaction was created. 565 | // 566 | // The watch channel is closed when a subsequent write transaction 567 | // has updated the result of the query. Since each read transaction 568 | // operates on an isolated snapshot, a new read transaction must be 569 | // started to observe the changes that have been made. 570 | // 571 | // If the value of index ends with "_prefix", LastWatch will perform a prefix 572 | // match instead of full match on the index. 
The registered indexer must implement 573 | // PrefixIndexer, otherwise an error is returned. 574 | func (txn *Txn) LastWatch(table, index string, args ...interface{}) (<-chan struct{}, interface{}, error) { 575 | // Get the index value 576 | indexSchema, val, err := txn.getIndexValue(table, index, args...) 577 | if err != nil { 578 | return nil, nil, err 579 | } 580 | 581 | // Get the index itself 582 | indexTxn := txn.readableIndex(table, indexSchema.Name) 583 | 584 | // Do an exact lookup 585 | if indexSchema.Unique && val != nil && indexSchema.Name == index { 586 | watch, obj, ok := indexTxn.GetWatch(val) 587 | if !ok { 588 | return watch, nil, nil 589 | } 590 | return watch, obj, nil 591 | } 592 | 593 | // Handle non-unique index by using an iterator and getting the last value 594 | iter := indexTxn.Root().ReverseIterator() 595 | watch := iter.SeekPrefixWatch(val) 596 | _, value, _ := iter.Previous() 597 | return watch, value, nil 598 | } 599 | 600 | // First is used to return the first matching object for 601 | // the given constraints on the index. 602 | // 603 | // Note that all values read in the transaction form a consistent snapshot 604 | // from the time when the transaction was created. 605 | func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) { 606 | _, val, err := txn.FirstWatch(table, index, args...) 607 | return val, err 608 | } 609 | 610 | // Last is used to return the last matching object for 611 | // the given constraints on the index. 612 | // 613 | // Note that all values read in the transaction form a consistent snapshot 614 | // from the time when the transaction was created. 615 | func (txn *Txn) Last(table, index string, args ...interface{}) (interface{}, error) { 616 | _, val, err := txn.LastWatch(table, index, args...) 617 | return val, err 618 | } 619 | 620 | // LongestPrefix is used to fetch the longest prefix match for the given 621 | // constraints on the index. 
Note that this will not work with the memdb 622 | // StringFieldIndex because it adds null terminators which prevent the 623 | // algorithm from correctly finding a match (it will get to right before the 624 | // null and fail to find a leaf node). This should only be used where the prefix 625 | // given is capable of matching indexed entries directly, which typically only 626 | // applies to a custom indexer. See the unit test for an example. 627 | // 628 | // Note that all values read in the transaction form a consistent snapshot 629 | // from the time when the transaction was created. 630 | func (txn *Txn) LongestPrefix(table, index string, args ...interface{}) (interface{}, error) { 631 | // Enforce that this only works on prefix indexes. 632 | if !strings.HasSuffix(index, "_prefix") { 633 | return nil, fmt.Errorf("must use '%s_prefix' on index", index) 634 | } 635 | 636 | // Get the index value. 637 | indexSchema, val, err := txn.getIndexValue(table, index, args...) 638 | if err != nil { 639 | return nil, err 640 | } 641 | 642 | // This algorithm only makes sense against a unique index, otherwise the 643 | // index keys will have the IDs appended to them. 644 | if !indexSchema.Unique { 645 | return nil, fmt.Errorf("index '%s' is not unique", index) 646 | } 647 | 648 | // Find the longest prefix match with the given index. 649 | indexTxn := txn.readableIndex(table, indexSchema.Name) 650 | if _, value, ok := indexTxn.Root().LongestPrefix(val); ok { 651 | return value, nil 652 | } 653 | return nil, nil 654 | } 655 | 656 | // getIndexValue is used to get the IndexSchema and the value 657 | // used to scan the index given the parameters. This handles prefix based 658 | // scans when the index has the "_prefix" suffix. The index must support 659 | // prefix iteration. 
660 | func (txn *Txn) getIndexValue(table, index string, args ...interface{}) (*IndexSchema, []byte, error) { 661 | // Get the table schema 662 | tableSchema, ok := txn.db.schema.Tables[table] 663 | if !ok { 664 | return nil, nil, fmt.Errorf("invalid table '%s'", table) 665 | } 666 | 667 | // Check for a prefix scan 668 | prefixScan := false 669 | if strings.HasSuffix(index, "_prefix") { 670 | index = strings.TrimSuffix(index, "_prefix") 671 | prefixScan = true 672 | } 673 | 674 | // Get the index schema 675 | indexSchema, ok := tableSchema.Indexes[index] 676 | if !ok { 677 | return nil, nil, fmt.Errorf("invalid index '%s'", index) 678 | } 679 | 680 | // Hot-path for when there are no arguments 681 | if len(args) == 0 { 682 | return indexSchema, nil, nil 683 | } 684 | 685 | // Special case the prefix scanning 686 | if prefixScan { 687 | prefixIndexer, ok := indexSchema.Indexer.(PrefixIndexer) 688 | if !ok { 689 | return indexSchema, nil, 690 | fmt.Errorf("index '%s' does not support prefix scanning", index) 691 | } 692 | 693 | val, err := prefixIndexer.PrefixFromArgs(args...) 694 | if err != nil { 695 | return indexSchema, nil, fmt.Errorf("index error: %v", err) 696 | } 697 | return indexSchema, val, err 698 | } 699 | 700 | // Get the exact match index 701 | val, err := indexSchema.Indexer.FromArgs(args...) 702 | if err != nil { 703 | return indexSchema, nil, fmt.Errorf("index error: %v", err) 704 | } 705 | return indexSchema, val, err 706 | } 707 | 708 | // ResultIterator is used to iterate over a list of results from a query on a table. 709 | // 710 | // When a ResultIterator is created from a write transaction, the results from 711 | // Next will reflect a snapshot of the table at the time the ResultIterator is 712 | // created. 713 | // This means that calling Insert or Delete on a transaction while iterating is 714 | // allowed, but the changes made by Insert or Delete will not be observed in the 715 | // results returned from subsequent calls to Next. 
For example if an item is deleted 716 | // from the index used by the iterator it will still be returned by Next. If an 717 | // item is inserted into the index used by the iterator, it will not be returned 718 | // by Next. However, an iterator created after a call to Insert or Delete will 719 | // reflect the modifications. 720 | // 721 | // When a ResultIterator is created from a write transaction, and there are already 722 | // modifications to the index used by the iterator, the modification cache of the 723 | // index will be invalidated. This may result in some additional allocations if 724 | // the same node in the index is modified again. 725 | type ResultIterator interface { 726 | WatchCh() <-chan struct{} 727 | // Next returns the next result from the iterator. If there are no more results 728 | // nil is returned. 729 | Next() interface{} 730 | } 731 | 732 | // Get is used to construct a ResultIterator over all the rows that match the 733 | // given constraints of an index. The index values must match exactly (this 734 | // is not a range-based or prefix-based lookup) by default. 735 | // 736 | // Prefix lookups: if the named index implements PrefixIndexer, you may perform 737 | // prefix-based lookups by appending "_prefix" to the index name. In this 738 | // scenario, the index values given in args are treated as prefix lookups. For 739 | // example, a StringFieldIndex will match any string with the given value 740 | // as a prefix: "mem" matches "memdb". 741 | // 742 | // See the documentation for ResultIterator to understand the behaviour of the 743 | // returned ResultIterator. 744 | func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, error) { 745 | indexIter, val, err := txn.getIndexIterator(table, index, args...) 
746 | if err != nil { 747 | return nil, err 748 | } 749 | 750 | // Seek the iterator to the appropriate sub-set 751 | watchCh := indexIter.SeekPrefixWatch(val) 752 | 753 | // Create an iterator 754 | iter := &radixIterator{ 755 | iter: indexIter, 756 | watchCh: watchCh, 757 | } 758 | return iter, nil 759 | } 760 | 761 | // GetReverse is used to construct a Reverse ResultIterator over all the 762 | // rows that match the given constraints of an index. 763 | // The returned ResultIterator's Next() will return the next Previous value. 764 | // 765 | // See the documentation on Get for details on arguments. 766 | // 767 | // See the documentation for ResultIterator to understand the behaviour of the 768 | // returned ResultIterator. 769 | func (txn *Txn) GetReverse(table, index string, args ...interface{}) (ResultIterator, error) { 770 | indexIter, val, err := txn.getIndexIteratorReverse(table, index, args...) 771 | if err != nil { 772 | return nil, err 773 | } 774 | 775 | // Seek the iterator to the appropriate sub-set 776 | watchCh := indexIter.SeekPrefixWatch(val) 777 | 778 | // Create an iterator 779 | iter := &radixReverseIterator{ 780 | iter: indexIter, 781 | watchCh: watchCh, 782 | } 783 | return iter, nil 784 | } 785 | 786 | // LowerBound is used to construct a ResultIterator over all the range of 787 | // rows that have an index value greater than or equal to the provided args. 788 | // Calling this then iterating until the rows are larger than required allows 789 | // range scans within an index. It is not possible to watch the resulting 790 | // iterator since the radix tree doesn't efficiently allow watching on lower 791 | // bound changes. The WatchCh returned will be nil and so will block forever. 792 | // 793 | // If the value of index ends with "_prefix", LowerBound will perform a prefix match instead of 794 | // a full match on the index. The registered index must implement PrefixIndexer, 795 | // otherwise an error is returned. 
796 | // 797 | // See the documentation for ResultIterator to understand the behaviour of the 798 | // returned ResultIterator. 799 | func (txn *Txn) LowerBound(table, index string, args ...interface{}) (ResultIterator, error) { 800 | indexIter, val, err := txn.getIndexIterator(table, index, args...) 801 | if err != nil { 802 | return nil, err 803 | } 804 | 805 | // Seek the iterator to the appropriate sub-set 806 | indexIter.SeekLowerBound(val) 807 | 808 | // Create an iterator 809 | iter := &radixIterator{ 810 | iter: indexIter, 811 | } 812 | return iter, nil 813 | } 814 | 815 | // ReverseLowerBound is used to construct a Reverse ResultIterator over all the 816 | // range of rows that have an index value less than or equal to the 817 | // provided args. Calling this then iterating until the rows are lower than 818 | // required allows range scans within an index. It is not possible to watch the 819 | // resulting iterator since the radix tree doesn't efficiently allow watching 820 | // on lower bound changes. The WatchCh returned will be nil and so will block 821 | // forever. 822 | // 823 | // See the documentation for ResultIterator to understand the behaviour of the 824 | // returned ResultIterator. 825 | func (txn *Txn) ReverseLowerBound(table, index string, args ...interface{}) (ResultIterator, error) { 826 | indexIter, val, err := txn.getIndexIteratorReverse(table, index, args...) 827 | if err != nil { 828 | return nil, err 829 | } 830 | 831 | // Seek the iterator to the appropriate sub-set 832 | indexIter.SeekReverseLowerBound(val) 833 | 834 | // Create an iterator 835 | iter := &radixReverseIterator{ 836 | iter: indexIter, 837 | } 838 | return iter, nil 839 | } 840 | 841 | // objectID is a tuple of table name and the raw internal id byte slice 842 | // converted to a string. It's only converted to a string to make it comparable 843 | // so this struct can be used as a map index. 
844 | type objectID struct { 845 | Table string 846 | IndexVal string 847 | } 848 | 849 | // mutInfo stores metadata about mutations to allow collapsing multiple 850 | // mutations to the same object into one. 851 | type mutInfo struct { 852 | firstBefore interface{} 853 | lastIdx int 854 | } 855 | 856 | // Changes returns the set of object changes that have been made in the 857 | // transaction so far. If change tracking is not enabled it will always return 858 | // nil. It can be called before or after Commit. If it is before Commit it will 859 | // return all changes made so far which may not be the same as the final 860 | // Changes. After abort it will always return nil. As with other Txn methods 861 | // it's not safe to call this from a different goroutine than the one making 862 | // mutations or committing the transaction. Mutations will appear in the order 863 | // they were performed in the transaction but multiple operations to the same 864 | // object will be collapsed so only the effective overall change to that object 865 | // is present. If transaction operations are dependent (e.g. copy object X to Y 866 | // then delete X) this might mean the set of mutations is incomplete to verify 867 | // history, but it is complete in that the net effect is preserved (Y got a new 868 | // value, X got removed). 869 | func (txn *Txn) Changes() Changes { 870 | if txn.changes == nil { 871 | return nil 872 | } 873 | 874 | // De-duplicate mutations by key so all take effect at the point of the last 875 | // write but we keep the mutations in order. 
876 | dups := make(map[objectID]mutInfo) 877 | for i, m := range txn.changes { 878 | oid := objectID{ 879 | Table: m.Table, 880 | IndexVal: string(m.primaryKey), 881 | } 882 | // Store the latest mutation index for each key value 883 | mi, ok := dups[oid] 884 | if !ok { 885 | // First entry for key, store the before value 886 | mi.firstBefore = m.Before 887 | } 888 | mi.lastIdx = i 889 | dups[oid] = mi 890 | } 891 | if len(dups) == len(txn.changes) { 892 | // No duplicates found, fast path return it as is 893 | return txn.changes 894 | } 895 | 896 | // Need to remove the duplicates 897 | cs := make(Changes, 0, len(dups)) 898 | for i, m := range txn.changes { 899 | oid := objectID{ 900 | Table: m.Table, 901 | IndexVal: string(m.primaryKey), 902 | } 903 | mi := dups[oid] 904 | if mi.lastIdx == i { 905 | // This was the latest value for this key copy it with the before value in 906 | // case it's different. Note that m is not a pointer so we are not 907 | // modifying txn.changes here - it's already a copy. 908 | m.Before = mi.firstBefore 909 | 910 | // Edge case - if the object was inserted and then eventually deleted in 911 | // the same transaction, then the net effect on that key is a no-op. Don't 912 | // emit a mutation with nil for before and after as it's meaningless and 913 | // might violate expectations and cause a panic in code that assumes at 914 | // least one must be set. 915 | if m.Before == nil && m.After == nil { 916 | continue 917 | } 918 | cs = append(cs, m) 919 | } 920 | } 921 | // Store the de-duped version in case this is called again 922 | txn.changes = cs 923 | return cs 924 | } 925 | 926 | func (txn *Txn) getIndexIterator(table, index string, args ...interface{}) (*iradix.Iterator, []byte, error) { 927 | // Get the index value to scan 928 | indexSchema, val, err := txn.getIndexValue(table, index, args...) 
929 | if err != nil { 930 | return nil, nil, err 931 | } 932 | 933 | // Get the index itself 934 | indexTxn := txn.readableIndex(table, indexSchema.Name) 935 | indexRoot := indexTxn.Root() 936 | 937 | // Get an iterator over the index 938 | indexIter := indexRoot.Iterator() 939 | return indexIter, val, nil 940 | } 941 | 942 | func (txn *Txn) getIndexIteratorReverse(table, index string, args ...interface{}) (*iradix.ReverseIterator, []byte, error) { 943 | // Get the index value to scan 944 | indexSchema, val, err := txn.getIndexValue(table, index, args...) 945 | if err != nil { 946 | return nil, nil, err 947 | } 948 | 949 | // Get the index itself 950 | indexTxn := txn.readableIndex(table, indexSchema.Name) 951 | indexRoot := indexTxn.Root() 952 | 953 | // Get an iterator over the index 954 | indexIter := indexRoot.ReverseIterator() 955 | return indexIter, val, nil 956 | } 957 | 958 | // Defer is used to push a new arbitrary function onto a stack which 959 | // gets called when a transaction is committed and finished. Deferred 960 | // functions are called in LIFO order, and only invoked at the end of 961 | // write transactions. 962 | func (txn *Txn) Defer(fn func()) { 963 | txn.after = append(txn.after, fn) 964 | } 965 | 966 | // radixIterator is used to wrap an underlying iradix iterator. 967 | // This is much more efficient than a sliceIterator as we are not 968 | // materializing the entire view. 
969 | type radixIterator struct { 970 | iter *iradix.Iterator 971 | watchCh <-chan struct{} 972 | } 973 | 974 | func (r *radixIterator) WatchCh() <-chan struct{} { 975 | return r.watchCh 976 | } 977 | 978 | func (r *radixIterator) Next() interface{} { 979 | _, value, ok := r.iter.Next() 980 | if !ok { 981 | return nil 982 | } 983 | return value 984 | } 985 | 986 | type radixReverseIterator struct { 987 | iter *iradix.ReverseIterator 988 | watchCh <-chan struct{} 989 | } 990 | 991 | func (r *radixReverseIterator) Next() interface{} { 992 | _, value, ok := r.iter.Previous() 993 | if !ok { 994 | return nil 995 | } 996 | return value 997 | } 998 | 999 | func (r *radixReverseIterator) WatchCh() <-chan struct{} { 1000 | return r.watchCh 1001 | } 1002 | 1003 | // Snapshot creates a snapshot of the current state of the transaction. 1004 | // Returns a new read-only transaction or nil if the transaction is already 1005 | // aborted or committed. 1006 | func (txn *Txn) Snapshot() *Txn { 1007 | if txn.rootTxn == nil { 1008 | return nil 1009 | } 1010 | 1011 | snapshot := &Txn{ 1012 | db: txn.db, 1013 | rootTxn: txn.rootTxn.Clone(), 1014 | } 1015 | 1016 | // Commit sub-transactions into the snapshot 1017 | for key, subTxn := range txn.modified { 1018 | path := indexPath(key.Table, key.Index) 1019 | final := subTxn.CommitOnly() 1020 | snapshot.rootTxn.Insert(path, final) 1021 | } 1022 | 1023 | return snapshot 1024 | } 1025 | -------------------------------------------------------------------------------- /watch-gen/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | // This tool generates the special-case code for a small number of watchers 5 | // which runs all the watches in a single select vs. needing to spawn a 6 | // goroutine for each one. 
7 | package main 8 | 9 | import ( 10 | "fmt" 11 | "os" 12 | "text/template" 13 | ) 14 | 15 | // aFew should be set to the number of channels to special-case for. Setting 16 | // this is a tradeoff for how big the slice is for the smallest watch set that 17 | // we see in practice vs. the number of goroutines we save when dealing with a 18 | // large number of watches. This was tuned with BenchmarkWatch to get setup 19 | // time for a watch with 1024 channels under 100 us on a 2.7 GHz Core i5. 20 | const aFew = 32 21 | 22 | // source is the template we use to generate the source file. 23 | const source = `package memdb 24 | 25 | //go:generate sh -c "go run watch-gen/main.go >watch_few.go" 26 | 27 | import( 28 | "context" 29 | ) 30 | 31 | // aFew gives how many watchers this function is wired to support. You must 32 | // always pass a full slice of this length, but unused channels can be nil. 33 | const aFew = {{len .}} 34 | 35 | // watchFew is used if there are only a few watchers as a performance 36 | // optimization. 37 | func watchFew(ctx context.Context, ch []<-chan struct{}) error { 38 | select { 39 | {{range $i, $unused := .}} 40 | case <-ch[{{printf "%d" $i}}]: 41 | return nil 42 | {{end}} 43 | case <-ctx.Done(): 44 | return ctx.Err() 45 | } 46 | } 47 | ` 48 | 49 | // render prints the template to stdout. 50 | func render() error { 51 | tmpl, err := template.New("watch").Parse(source) 52 | if err != nil { 53 | return err 54 | } 55 | if err := tmpl.Execute(os.Stdout, make([]struct{}, aFew)); err != nil { 56 | return err 57 | } 58 | return nil 59 | } 60 | 61 | func main() { 62 | if err := render(); err != nil { 63 | fmt.Fprintln(os.Stderr, err.Error()) 64 | os.Exit(1) 65 | } else { 66 | os.Exit(0) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /watch.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import ( 7 | "context" 8 | "time" 9 | ) 10 | 11 | // WatchSet is a collection of watch channels. The zero value is not usable. 12 | // Use NewWatchSet to create a WatchSet. 13 | type WatchSet map[<-chan struct{}]struct{} 14 | 15 | // NewWatchSet constructs a new watch set. 16 | func NewWatchSet() WatchSet { 17 | return make(map[<-chan struct{}]struct{}) 18 | } 19 | 20 | // Add appends a watchCh to the WatchSet if non-nil. 21 | func (w WatchSet) Add(watchCh <-chan struct{}) { 22 | if w == nil { 23 | return 24 | } 25 | 26 | if _, ok := w[watchCh]; !ok { 27 | w[watchCh] = struct{}{} 28 | } 29 | } 30 | 31 | // AddWithLimit appends a watchCh to the WatchSet if non-nil, and if the given 32 | // softLimit hasn't been exceeded. Otherwise, it will watch the given alternate 33 | // channel. It's expected that the altCh will be the same on many calls to this 34 | // function, so you will exceed the soft limit a little bit if you hit this, but 35 | // not by much. 36 | // 37 | // This is useful if you want to track individual items up to some limit, after 38 | // which you watch a higher-level channel (usually a channel from start of 39 | // an iterator higher up in the radix tree) that will watch a superset of items. 40 | func (w WatchSet) AddWithLimit(softLimit int, watchCh <-chan struct{}, altCh <-chan struct{}) { 41 | // This is safe for a nil WatchSet so we don't need to check that here. 42 | if len(w) < softLimit { 43 | w.Add(watchCh) 44 | } else { 45 | w.Add(altCh) 46 | } 47 | } 48 | 49 | // Watch blocks until one of the channels in the watch set is closed, or 50 | // timeoutCh sends a value. 51 | // Returns true if timeoutCh is what caused Watch to unblock. 
52 | func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool { 53 | if w == nil { 54 | return false 55 | } 56 | 57 | // Create a context that gets cancelled when the timeout is triggered 58 | ctx, cancel := context.WithCancel(context.Background()) 59 | defer cancel() 60 | 61 | go func() { 62 | select { 63 | case <-timeoutCh: 64 | cancel() 65 | case <-ctx.Done(): 66 | } 67 | }() 68 | 69 | return w.WatchCtx(ctx) == context.Canceled 70 | } 71 | 72 | // WatchCtx blocks until one of the channels in the watch set is closed, or 73 | // ctx is done (cancelled or exceeds the deadline). WatchCtx returns an error 74 | // if the ctx causes it to unblock, otherwise returns nil. 75 | // 76 | // WatchCtx should be preferred over Watch. 77 | func (w WatchSet) WatchCtx(ctx context.Context) error { 78 | if w == nil { 79 | return nil 80 | } 81 | 82 | if n := len(w); n <= aFew { 83 | idx := 0 84 | chunk := make([]<-chan struct{}, aFew) 85 | for watchCh := range w { 86 | chunk[idx] = watchCh 87 | idx++ 88 | } 89 | return watchFew(ctx, chunk) 90 | } 91 | 92 | return w.watchMany(ctx) 93 | } 94 | 95 | // watchMany is used if there are many watchers. 96 | func (w WatchSet) watchMany(ctx context.Context) error { 97 | // Cancel all watcher goroutines when return. 98 | watcherCtx, cancel := context.WithCancel(ctx) 99 | defer cancel() 100 | 101 | // Set up a goroutine for each watcher. 102 | triggerCh := make(chan struct{}, 1) 103 | watcher := func(chunk []<-chan struct{}) { 104 | if err := watchFew(watcherCtx, chunk); err == nil { 105 | select { 106 | case triggerCh <- struct{}{}: 107 | default: 108 | } 109 | } 110 | } 111 | 112 | // Apportion the watch channels into chunks we can feed into the 113 | // watchFew helper. 114 | idx := 0 115 | chunk := make([]<-chan struct{}, aFew) 116 | for watchCh := range w { 117 | subIdx := idx % aFew 118 | chunk[subIdx] = watchCh 119 | idx++ 120 | 121 | // Fire off this chunk and start a fresh one. 
122 | if idx%aFew == 0 { 123 | go watcher(chunk) 124 | chunk = make([]<-chan struct{}, aFew) 125 | } 126 | } 127 | 128 | // Make sure to watch any residual channels in the last chunk. 129 | if idx%aFew != 0 { 130 | go watcher(chunk) 131 | } 132 | 133 | // Wait for a channel to trigger or timeout. 134 | select { 135 | case <-triggerCh: 136 | return nil 137 | case <-ctx.Done(): 138 | return ctx.Err() 139 | } 140 | } 141 | 142 | // WatchCh returns a channel that is used to wait for any channel of the watch set to trigger 143 | // or for the context to be cancelled. WatchCh creates a new goroutine each call, so 144 | // callers may need to cache the returned channel to avoid creating extra goroutines. 145 | func (w WatchSet) WatchCh(ctx context.Context) <-chan error { 146 | // Create the outgoing channel 147 | triggerCh := make(chan error, 1) 148 | 149 | // Create a goroutine to collect the error from WatchCtx 150 | go func() { 151 | triggerCh <- w.WatchCtx(ctx) 152 | }() 153 | 154 | return triggerCh 155 | } 156 | -------------------------------------------------------------------------------- /watch_few.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | //go:generate sh -c "go run watch-gen/main.go >watch_few.go" 7 | 8 | import ( 9 | "context" 10 | ) 11 | 12 | // aFew gives how many watchers this function is wired to support. You must 13 | // always pass a full slice of this length, but unused channels can be nil. 14 | const aFew = 32 15 | 16 | // watchFew is used if there are only a few watchers as a performance 17 | // optimization. 
18 | func watchFew(ctx context.Context, ch []<-chan struct{}) error { 19 | select { 20 | 21 | case <-ch[0]: 22 | return nil 23 | 24 | case <-ch[1]: 25 | return nil 26 | 27 | case <-ch[2]: 28 | return nil 29 | 30 | case <-ch[3]: 31 | return nil 32 | 33 | case <-ch[4]: 34 | return nil 35 | 36 | case <-ch[5]: 37 | return nil 38 | 39 | case <-ch[6]: 40 | return nil 41 | 42 | case <-ch[7]: 43 | return nil 44 | 45 | case <-ch[8]: 46 | return nil 47 | 48 | case <-ch[9]: 49 | return nil 50 | 51 | case <-ch[10]: 52 | return nil 53 | 54 | case <-ch[11]: 55 | return nil 56 | 57 | case <-ch[12]: 58 | return nil 59 | 60 | case <-ch[13]: 61 | return nil 62 | 63 | case <-ch[14]: 64 | return nil 65 | 66 | case <-ch[15]: 67 | return nil 68 | 69 | case <-ch[16]: 70 | return nil 71 | 72 | case <-ch[17]: 73 | return nil 74 | 75 | case <-ch[18]: 76 | return nil 77 | 78 | case <-ch[19]: 79 | return nil 80 | 81 | case <-ch[20]: 82 | return nil 83 | 84 | case <-ch[21]: 85 | return nil 86 | 87 | case <-ch[22]: 88 | return nil 89 | 90 | case <-ch[23]: 91 | return nil 92 | 93 | case <-ch[24]: 94 | return nil 95 | 96 | case <-ch[25]: 97 | return nil 98 | 99 | case <-ch[26]: 100 | return nil 101 | 102 | case <-ch[27]: 103 | return nil 104 | 105 | case <-ch[28]: 106 | return nil 107 | 108 | case <-ch[29]: 109 | return nil 110 | 111 | case <-ch[30]: 112 | return nil 113 | 114 | case <-ch[31]: 115 | return nil 116 | 117 | case <-ctx.Done(): 118 | return ctx.Err() 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /watch_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package memdb 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "fmt" 10 | "runtime/pprof" 11 | "strings" 12 | "testing" 13 | "time" 14 | ) 15 | 16 | // testWatch makes a bunch of watch channels based on the given size and fires 17 | // the one at the given fire index to make sure it's detected (or a timeout 18 | // occurs if the fire index isn't hit). useCtx parameterizes whether the context 19 | // based watch is used or timer based. 20 | func testWatch(size, fire int, useCtx bool) error { 21 | shouldTimeout := true 22 | ws := NewWatchSet() 23 | for i := 0; i < size; i++ { 24 | watchCh := make(chan struct{}) 25 | ws.Add(watchCh) 26 | if fire == i { 27 | close(watchCh) 28 | shouldTimeout = false 29 | } 30 | } 31 | 32 | var timeoutCh chan time.Time 33 | var ctx context.Context 34 | var cancelFn context.CancelFunc 35 | if useCtx { 36 | ctx, cancelFn = context.WithCancel(context.Background()) 37 | defer cancelFn() 38 | } else { 39 | timeoutCh = make(chan time.Time) 40 | } 41 | 42 | doneCh := make(chan bool, 1) 43 | go func() { 44 | if useCtx { 45 | doneCh <- ws.WatchCtx(ctx) != nil 46 | } else { 47 | doneCh <- ws.Watch(timeoutCh) 48 | } 49 | }() 50 | 51 | if shouldTimeout { 52 | select { 53 | case <-doneCh: 54 | return fmt.Errorf("should not trigger") 55 | default: 56 | } 57 | 58 | if useCtx { 59 | cancelFn() 60 | } else { 61 | close(timeoutCh) 62 | } 63 | select { 64 | case didTimeout := <-doneCh: 65 | if !didTimeout { 66 | return fmt.Errorf("should have timed out") 67 | } 68 | case <-time.After(10 * time.Second): 69 | return fmt.Errorf("should have timed out") 70 | } 71 | } else { 72 | select { 73 | case didTimeout := <-doneCh: 74 | if didTimeout { 75 | return fmt.Errorf("should not have timed out") 76 | } 77 | case <-time.After(10 * time.Second): 78 | return fmt.Errorf("should have triggered") 79 | } 80 | if useCtx { 81 | cancelFn() 82 | } else { 83 | close(timeoutCh) 84 | } 85 | } 86 | return nil 87 | } 88 | 89 | func 
TestWatch(t *testing.T) { 90 | testFactory := func(useCtx bool) func(t *testing.T) { 91 | return func(t *testing.T) { 92 | // Sweep through a bunch of chunks to hit the various cases of dividing 93 | // the work into watchFew calls. 94 | for size := 0; size < 3*aFew; size++ { 95 | // Fire each possible channel slot. 96 | for fire := 0; fire < size; fire++ { 97 | if err := testWatch(size, fire, useCtx); err != nil { 98 | t.Fatalf("err %d %d: %v", size, fire, err) 99 | } 100 | } 101 | 102 | // Run a timeout case as well. 103 | fire := -1 104 | if err := testWatch(size, fire, useCtx); err != nil { 105 | t.Fatalf("err %d %d: %v", size, fire, err) 106 | } 107 | } 108 | } 109 | } 110 | 111 | t.Run("Timer", testFactory(false)) 112 | t.Run("Context", testFactory(true)) 113 | } 114 | 115 | func testWatchCh(size, fire int) error { 116 | shouldTimeout := true 117 | ws := NewWatchSet() 118 | for i := 0; i < size; i++ { 119 | watchCh := make(chan struct{}) 120 | ws.Add(watchCh) 121 | if fire == i { 122 | close(watchCh) 123 | shouldTimeout = false 124 | } 125 | } 126 | 127 | ctx, cancelFn := context.WithCancel(context.Background()) 128 | defer cancelFn() 129 | 130 | doneCh := make(chan bool, 1) 131 | go func() { 132 | err := <-ws.WatchCh(ctx) 133 | doneCh <- err != nil 134 | }() 135 | 136 | if shouldTimeout { 137 | select { 138 | case <-doneCh: 139 | return fmt.Errorf("should not trigger") 140 | default: 141 | } 142 | 143 | cancelFn() 144 | select { 145 | case didTimeout := <-doneCh: 146 | if !didTimeout { 147 | return fmt.Errorf("should have timed out") 148 | } 149 | case <-time.After(10 * time.Second): 150 | return fmt.Errorf("should have timed out") 151 | } 152 | } else { 153 | select { 154 | case didTimeout := <-doneCh: 155 | if didTimeout { 156 | return fmt.Errorf("should not have timed out") 157 | } 158 | case <-time.After(10 * time.Second): 159 | return fmt.Errorf("should have triggered") 160 | } 161 | cancelFn() 162 | } 163 | return nil 164 | } 165 | 166 | func 
TestWatchChan(t *testing.T) { 167 | 168 | // Sweep through a bunch of chunks to hit the various cases of dividing 169 | // the work into watchFew calls. 170 | for size := 0; size < 3*aFew; size++ { 171 | // Fire each possible channel slot. 172 | for fire := 0; fire < size; fire++ { 173 | if err := testWatchCh(size, fire); err != nil { 174 | t.Fatalf("err %d %d: %v", size, fire, err) 175 | } 176 | } 177 | 178 | // Run a timeout case as well. 179 | fire := -1 180 | if err := testWatchCh(size, fire); err != nil { 181 | t.Fatalf("err %d %d: %v", size, fire, err) 182 | } 183 | } 184 | } 185 | 186 | func TestWatch_AddWithLimit(t *testing.T) { 187 | // Make sure nil doesn't crash. 188 | { 189 | var ws WatchSet 190 | ch := make(chan struct{}) 191 | ws.AddWithLimit(10, ch, ch) 192 | } 193 | 194 | // Run a case where we trigger a channel that should be in 195 | // there. 196 | { 197 | ws := NewWatchSet() 198 | inCh := make(chan struct{}) 199 | altCh := make(chan struct{}) 200 | ws.AddWithLimit(1, inCh, altCh) 201 | 202 | nopeCh := make(chan struct{}) 203 | ws.AddWithLimit(1, nopeCh, altCh) 204 | 205 | close(inCh) 206 | didTimeout := ws.Watch(time.After(1 * time.Second)) 207 | if didTimeout { 208 | t.Fatalf("bad") 209 | } 210 | } 211 | 212 | // Run a case where we trigger the alt channel that should have 213 | // been added. 214 | { 215 | ws := NewWatchSet() 216 | inCh := make(chan struct{}) 217 | altCh := make(chan struct{}) 218 | ws.AddWithLimit(1, inCh, altCh) 219 | 220 | nopeCh := make(chan struct{}) 221 | ws.AddWithLimit(1, nopeCh, altCh) 222 | 223 | close(altCh) 224 | didTimeout := ws.Watch(time.After(1 * time.Second)) 225 | if didTimeout { 226 | t.Fatalf("bad") 227 | } 228 | } 229 | 230 | // Run a case where we trigger the nope channel that should not have 231 | // been added. 
232 | { 233 | ws := NewWatchSet() 234 | inCh := make(chan struct{}) 235 | altCh := make(chan struct{}) 236 | ws.AddWithLimit(1, inCh, altCh) 237 | 238 | nopeCh := make(chan struct{}) 239 | ws.AddWithLimit(1, nopeCh, altCh) 240 | 241 | close(nopeCh) 242 | didTimeout := ws.Watch(time.After(1 * time.Second)) 243 | if !didTimeout { 244 | t.Fatalf("bad") 245 | } 246 | } 247 | } 248 | 249 | func TestWatchCtxLeak(t *testing.T) { 250 | ctx, cancel := context.WithCancel(context.Background()) 251 | defer cancel() 252 | 253 | // We add a large number of channels to a WatchSet then 254 | // call WatchCtx. If one of those channels fires, we 255 | // expect to see all the goroutines spawned by WatchCtx 256 | // cleaned up. 257 | pprof.Do(ctx, pprof.Labels("foo", "bar"), func(ctx context.Context) { 258 | ws := NewWatchSet() 259 | fireCh := make(chan struct{}) 260 | ws.Add(fireCh) 261 | for i := 0; i < 10000; i++ { 262 | watchCh := make(chan struct{}) 263 | ws.Add(watchCh) 264 | } 265 | result := make(chan error) 266 | go func() { 267 | result <- ws.WatchCtx(ctx) 268 | }() 269 | 270 | fireCh <- struct{}{} 271 | 272 | if err := <-result; err != nil { 273 | t.Fatalf("expected no err got: %v", err) 274 | } 275 | }) 276 | 277 | numRetries := 3 278 | var gced bool 279 | for i := 0; i < numRetries; i++ { 280 | var pb bytes.Buffer 281 | profiler := pprof.Lookup("goroutine") 282 | if profiler == nil { 283 | t.Fatal("unable to find profile") 284 | } 285 | err := profiler.WriteTo(&pb, 1) 286 | if err != nil { 287 | t.Fatalf("unable to read profile: %v", err) 288 | } 289 | // If the debug profile dump contains the string "foo", 290 | // it means one of the goroutines spawned in pprof.Do above 291 | // still appears in the capture. 
292 | if !strings.Contains(pb.String(), "foo") { 293 | gced = true 294 | break 295 | } else { 296 | t.Log("retrying") 297 | time.Sleep(1 * time.Second) 298 | } 299 | } 300 | if !gced { 301 | t.Errorf("goroutines were not garbage collected after %d retries", numRetries) 302 | } 303 | } 304 | 305 | func BenchmarkWatch(b *testing.B) { 306 | ws := NewWatchSet() 307 | for i := 0; i < 1024; i++ { 308 | watchCh := make(chan struct{}) 309 | ws.Add(watchCh) 310 | } 311 | 312 | timeoutCh := make(chan time.Time) 313 | close(timeoutCh) 314 | 315 | b.ResetTimer() 316 | for i := 0; i < b.N; i++ { 317 | ws.Watch(timeoutCh) 318 | } 319 | } 320 | --------------------------------------------------------------------------------