├── .github └── workflows │ ├── main.yml │ └── pr.yml ├── CODEOWNERS ├── LICENSE ├── Makefile ├── README.md ├── any_table.go ├── benchmarks_test.go ├── cell.go ├── db.go ├── db_test.go ├── deletetracker.go ├── derive.go ├── derive_test.go ├── doc.go ├── errors.go ├── fuzz_test.go ├── go.mod ├── go.sum ├── graveyard.go ├── http.go ├── http_client.go ├── http_test.go ├── index ├── bool.go ├── int.go ├── keyset.go ├── keyset_test.go ├── map.go ├── netip.go ├── seq.go ├── seq_test.go ├── set.go └── string.go ├── internal ├── sortable_mutex.go ├── sortable_mutex_test.go └── time.go ├── iterator.go ├── iterator_test.go ├── metrics.go ├── observable.go ├── part ├── cache.go ├── iterator.go ├── map.go ├── map_test.go ├── node.go ├── ops.go ├── part_test.go ├── quick_test.go ├── registry.go ├── set.go ├── set_test.go ├── tree.go └── txn.go ├── quick_test.go ├── reconciler ├── benchmark │ ├── .gitignore │ ├── main.go │ └── run.sh ├── builder.go ├── config.go ├── example │ ├── .gitignore │ ├── main.go │ ├── ops.go │ └── types.go ├── helpers.go ├── incremental.go ├── index.go ├── metrics.go ├── multi_test.go ├── reconciler.go ├── retries.go ├── retries_test.go ├── script_test.go ├── status_test.go ├── testdata │ ├── batching.txtar │ ├── incremental.txtar │ ├── prune_empty.txtar │ ├── pruning.txtar │ └── refresh.txtar └── types.go ├── regression_test.go ├── script.go ├── script_test.go ├── table.go ├── testdata └── db.txtar ├── txn.go ├── types.go ├── watchset.go └── watchset_test.go /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: main 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | env: 8 | GO_VERSION: 1.23 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/setup-go@v2 15 | with: 16 | go-version: ${{ env.GO_VERSION }} 17 | - uses: actions/checkout@v3 18 | with: 19 | fetch-depth: 0 20 | - name: test 21 | run: | 22 | make all 23 | 24 | 
-------------------------------------------------------------------------------- /.github/workflows/pr.yml: -------------------------------------------------------------------------------- 1 | name: pr 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | 7 | env: 8 | GO_VERSION: 1.23 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/setup-go@v2 15 | with: 16 | go-version: ${{ env.GO_VERSION }} 17 | - uses: actions/checkout@v3 18 | with: 19 | fetch-depth: 0 20 | 21 | - name: test 22 | run: | 23 | set -o pipefail 24 | echo '```' > results.comment 25 | echo "$ make" >> results.comment 26 | make 2>&1 | tee -a results.comment 27 | 28 | - name: close 29 | if: success() || failure() 30 | run: | 31 | echo '```' >> results.comment 32 | 33 | - name: results 34 | if: success() || failure() 35 | uses: thollander/actions-comment-pull-request@v2 36 | with: 37 | comment_tag: results 38 | filePath: results.comment 39 | 40 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Code owners groups assigned to this repository and a brief description of their areas: 2 | # @cilium/ci-structure Continuous integration, testing 3 | # @cilium/contributing Developer documentation & tools 4 | # @cilium/github-sec GitHub security (handling of secrets, consequences of pull_request_target, etc.) 5 | # @cilium/sig-foundations Core libraries and guidance to overall software architecture. 6 | 7 | # The following filepaths should be sorted so that more specific paths occur 8 | # after the less specific paths, otherwise the ownership for the specific paths 9 | # is not properly picked up in Github. 
# Every target below is a command, not a file, so all of them are declared
# phony. This keeps a stray file named e.g. "bench-reconciler" from
# silently disabling the target.
.PHONY: all build test test-race bench bench-reconciler

# Default target: build, then run the tests (plain and with the race
# detector) and the benchmarks.
all: build test test-race bench

build:
	go build ./...

test:
	go test ./... -cover -vet=all -test.count 1

test-race:
	go test -race ./... -test.count 1

# Runs the Go benchmarks ("-test.run xxx" is the test filter used here) and
# then the reconciler benchmark with -quiet.
bench:
	go test ./... -bench . -benchmem -test.run xxx
	go run ./reconciler/benchmark -quiet

# Runs only the reconciler benchmark, without -quiet.
bench-reconciler:
	go run ./reconciler/benchmark
14 | type AnyTable struct { 15 | Meta TableMeta 16 | } 17 | 18 | func (t AnyTable) NumObjects(txn ReadTxn) int { 19 | indexTxn := txn.getTxn().mustIndexReadTxn(t.Meta, PrimaryIndexPos) 20 | return indexTxn.Len() 21 | } 22 | 23 | func (t AnyTable) All(txn ReadTxn) iter.Seq2[any, Revision] { 24 | all, _ := t.AllWatch(txn) 25 | return all 26 | } 27 | 28 | func (t AnyTable) AllWatch(txn ReadTxn) (iter.Seq2[any, Revision], <-chan struct{}) { 29 | indexTxn := txn.getTxn().mustIndexReadTxn(t.Meta, PrimaryIndexPos) 30 | return partSeq[any](indexTxn.Iterator()), indexTxn.RootWatch() 31 | } 32 | 33 | func (t AnyTable) UnmarshalYAML(data []byte) (any, error) { 34 | return t.Meta.unmarshalYAML(data) 35 | } 36 | 37 | func (t AnyTable) Insert(txn WriteTxn, obj any) (old any, hadOld bool, err error) { 38 | var iobj object 39 | iobj, hadOld, _, err = txn.getTxn().insert(t.Meta, Revision(0), obj) 40 | if hadOld { 41 | old = iobj.data 42 | } 43 | return 44 | } 45 | 46 | func (t AnyTable) Delete(txn WriteTxn, obj any) (old any, hadOld bool, err error) { 47 | var iobj object 48 | iobj, hadOld, err = txn.getTxn().delete(t.Meta, Revision(0), obj) 49 | if hadOld { 50 | old = iobj.data 51 | } 52 | return 53 | } 54 | 55 | func (t AnyTable) Get(txn ReadTxn, index string, key string) (any, Revision, bool, error) { 56 | itxn, rawKey, err := t.queryIndex(txn, index, key) 57 | if err != nil { 58 | return nil, 0, false, err 59 | } 60 | if itxn.unique { 61 | obj, _, ok := itxn.Get(rawKey) 62 | return obj.data, obj.revision, ok, nil 63 | } 64 | // For non-unique indexes we need to prefix search and make sure to fully 65 | // match the secondary key. 
66 | iter, _ := itxn.Prefix(rawKey) 67 | for { 68 | k, obj, ok := iter.Next() 69 | if !ok { 70 | break 71 | } 72 | if nonUniqueKey(k).secondaryLen() == len(rawKey) { 73 | return obj.data, obj.revision, true, nil 74 | } 75 | } 76 | return nil, 0, false, nil 77 | } 78 | 79 | func (t AnyTable) Prefix(txn ReadTxn, index string, key string) (iter.Seq2[any, Revision], error) { 80 | itxn, rawKey, err := t.queryIndex(txn, index, key) 81 | if err != nil { 82 | return nil, err 83 | } 84 | iter, _ := itxn.Prefix(rawKey) 85 | if itxn.unique { 86 | return partSeq[any](iter), nil 87 | } 88 | return nonUniqueSeq[any](iter, true, rawKey), nil 89 | } 90 | 91 | func (t AnyTable) LowerBound(txn ReadTxn, index string, key string) (iter.Seq2[any, Revision], error) { 92 | itxn, rawKey, err := t.queryIndex(txn, index, key) 93 | if err != nil { 94 | return nil, err 95 | } 96 | iter := itxn.LowerBound(rawKey) 97 | if itxn.unique { 98 | return partSeq[any](iter), nil 99 | } 100 | return nonUniqueLowerBoundSeq[any](iter, rawKey), nil 101 | } 102 | 103 | func (t AnyTable) List(txn ReadTxn, index string, key string) (iter.Seq2[any, Revision], error) { 104 | itxn, rawKey, err := t.queryIndex(txn, index, key) 105 | if err != nil { 106 | return nil, err 107 | } 108 | iter, _ := itxn.Prefix(rawKey) 109 | if itxn.unique { 110 | // Unique index means that there can be only a single matching object. 111 | // Doing a Get() is more efficient than constructing an iterator. 
112 | value, _, ok := itxn.Get(rawKey) 113 | return func(yield func(any, Revision) bool) { 114 | if ok { 115 | yield(value.data, value.revision) 116 | } 117 | }, nil 118 | } 119 | return nonUniqueSeq[any](iter, false, rawKey), nil 120 | } 121 | 122 | func (t AnyTable) queryIndex(txn ReadTxn, index string, key string) (indexReadTxn, []byte, error) { 123 | indexer := t.Meta.getIndexer(index) 124 | if indexer == nil { 125 | return indexReadTxn{}, nil, fmt.Errorf("invalid index %q", index) 126 | } 127 | rawKey, err := indexer.fromString(key) 128 | if err != nil { 129 | return indexReadTxn{}, nil, err 130 | } 131 | itxn, err := txn.getTxn().indexReadTxn(t.Meta, indexer.pos) 132 | return itxn, rawKey, err 133 | } 134 | 135 | func (t AnyTable) Changes(txn WriteTxn) (anyChangeIterator, error) { 136 | return t.Meta.anyChanges(txn) 137 | } 138 | 139 | func (t AnyTable) TableHeader() []string { 140 | zero := t.Meta.proto() 141 | if tw, ok := zero.(TableWritable); ok { 142 | return tw.TableHeader() 143 | } 144 | return nil 145 | } 146 | 147 | func (t AnyTable) Proto() any { 148 | return t.Meta.proto() 149 | } 150 | -------------------------------------------------------------------------------- /cell.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "github.com/cilium/hive/cell" 8 | ) 9 | 10 | // This module provides an in-memory database built on top of immutable radix trees 11 | // As the database is based on an immutable data structure, the objects inserted into 12 | // the database MUST NOT be mutated, but rather copied first! 
13 | var Cell = cell.Module( 14 | "statedb", 15 | "In-memory transactional database", 16 | 17 | cell.Provide( 18 | newHiveDB, 19 | ScriptCommands, 20 | ), 21 | ) 22 | 23 | type params struct { 24 | cell.In 25 | 26 | Lifecycle cell.Lifecycle 27 | Metrics Metrics `optional:"true"` 28 | } 29 | 30 | func newHiveDB(p params) *DB { 31 | db := New(WithMetrics(p.Metrics)) 32 | p.Lifecycle.Append( 33 | cell.Hook{ 34 | OnStart: func(cell.HookContext) error { 35 | return db.Start() 36 | }, 37 | OnStop: func(cell.HookContext) error { 38 | return db.Stop() 39 | }, 40 | }) 41 | return db 42 | } 43 | -------------------------------------------------------------------------------- /deletetracker.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "sync/atomic" 8 | 9 | "github.com/cilium/statedb/index" 10 | ) 11 | 12 | type deleteTracker[Obj any] struct { 13 | db *DB 14 | trackerName string 15 | table Table[Obj] 16 | 17 | // revision is the last observed revision. Starts out at zero 18 | // in which case the garbage collector will not care about this 19 | // tracker when considering which objects to delete. 20 | revision atomic.Uint64 21 | } 22 | 23 | // setRevision is called to set the starting low watermark when 24 | // this deletion tracker is inserted into the table. 25 | func (dt *deleteTracker[Obj]) setRevision(rev uint64) { 26 | dt.revision.Store(rev) 27 | } 28 | 29 | // getRevision is called by the graveyard garbage collector to 30 | // compute the global low watermark. 31 | func (dt *deleteTracker[Obj]) getRevision() uint64 { 32 | return dt.revision.Load() 33 | } 34 | 35 | // Deleted returns an iterator for deleted objects in this table starting from 36 | // 'minRevision'. The deleted objects are not garbage-collected unless 'Mark' is 37 | // called! 
38 | func (dt *deleteTracker[Obj]) deleted(txn *txn, minRevision Revision) Iterator[Obj] { 39 | indexEntry := txn.root[dt.table.tablePos()].indexes[GraveyardRevisionIndexPos] 40 | indexTxn := indexReadTxn{indexEntry.tree, indexEntry.unique} 41 | iter := indexTxn.LowerBound(index.Uint64(minRevision)) 42 | return &iterator[Obj]{iter} 43 | } 44 | 45 | // Mark the revision up to which deleted objects have been processed. This sets 46 | // the low watermark for deleted object garbage collection. 47 | func (dt *deleteTracker[Obj]) mark(upTo Revision) { 48 | // Store the new low watermark and trigger a round of garbage collection. 49 | dt.revision.Store(upTo) 50 | select { 51 | case dt.db.gcTrigger <- struct{}{}: 52 | default: 53 | } 54 | } 55 | 56 | func (dt *deleteTracker[Obj]) close() { 57 | if dt.db == nil { 58 | return 59 | } 60 | 61 | // Remove the delete tracker from the table. 62 | txn := dt.db.WriteTxn(dt.table).getTxn() 63 | dt.db = nil 64 | db := txn.db 65 | table := txn.modifiedTables[dt.table.tablePos()] 66 | if table == nil { 67 | panic("BUG: Table missing from write transaction") 68 | } 69 | _, _, table.deleteTrackers = table.deleteTrackers.Delete([]byte(dt.trackerName)) 70 | txn.Commit() 71 | 72 | db.metrics.DeleteTrackerCount(dt.table.Name(), table.deleteTrackers.Len()) 73 | 74 | // Trigger garbage collection without this delete tracker to garbage 75 | // collect any deleted objects that may not have been consumed. 
// DeriveResult is returned by the transform function given to Derive and
// selects what is done with the transformed object.
type DeriveResult int

const (
	DeriveInsert DeriveResult = 0 // Insert the object
	DeriveUpdate DeriveResult = 1 // Update the object (if it exists)
	DeriveDelete DeriveResult = 2 // Delete the object
	DeriveSkip   DeriveResult = 3 // Skip
)

// DeriveParams are the dependencies of the function returned by Derive:
// the job registry, health and lifecycle used to register the job, the
// database, and the input and output tables.
type DeriveParams[In, Out any] struct {
	cell.In

	Lifecycle cell.Lifecycle
	Jobs      job.Registry
	Health    cell.Health
	DB        *DB
	InTable   Table[In]
	OutTable  RWTable[Out]
}

// Derive constructs and registers a job to transform objects from the input table to the
// output table, i.e. derive the output table from the input table. Useful when constructing
// a reconciler that has its desired state solely derived from a single table. For example
// the bandwidth manager's desired state is directly derived from the devices table.
//
// Derive is parametrized with the name of the job and the transform function that
// transforms the input object into the output object. The DeriveResult returned by the
// transform function decides what happens to the output object: it is inserted
// (DeriveInsert), inserted only if it already exists in the output table (DeriveUpdate),
// deleted (DeriveDelete) or ignored (DeriveSkip).
//
// Example use:
//
//	cell.Invoke(
//	  statedb.Derive[*tables.Device, *Foo](
//	    "device-to-foo",
//	    func(d *Device, deleted bool) (*Foo, DeriveResult) {
//	      if deleted {
//	        return &Foo{Index: d.Index}, DeriveDelete
//	      }
//	      return &Foo{Index: d.Index}, DeriveInsert
//	    }),
//	)
func Derive[In, Out any](jobName string, transform func(obj In, deleted bool) (Out, DeriveResult)) func(DeriveParams[In, Out]) {
	return func(p DeriveParams[In, Out]) {
		// The derivation runs as a one-shot job in a job group created from
		// the provided health and lifecycle.
		g := p.Jobs.NewGroup(p.Health, p.Lifecycle)
		g.Add(job.OneShot(
			jobName,
			derive[In, Out]{p, jobName, transform}.loop),
		)
	}
}

// derive bundles the parameters, job name and transform function for the
// job registered by Derive.
type derive[In, Out any] struct {
	DeriveParams[In, Out]
	jobName   string
	transform func(obj In, deleted bool) (Out, DeriveResult)
}

// loop is the body of the job. It obtains a change iterator for the input
// table and then repeatedly applies the transformed changes to the output
// table, one write transaction per batch of changes. It returns nil when ctx
// is done, and the error if creating the change iterator or writing to the
// output table fails.
func (d derive[In, Out]) loop(ctx context.Context, _ cell.Health) error {
	out := d.OutTable
	// The change iterator is requested within a write transaction against
	// the input table. The transaction is committed before the error from
	// Changes is checked.
	txn := d.DB.WriteTxn(d.InTable)
	iter, err := d.InTable.Changes(txn)
	txn.Commit()
	if err != nil {
		return err
	}
	for {
		// Only the output table is opened for writing; the pending changes
		// of the input table are read through the same transaction.
		wtxn := d.DB.WriteTxn(out)
		changes, watch := iter.Next(wtxn)
		for change := range changes {
			outObj, result := d.transform(change.Object, change.Deleted)
			switch result {
			case DeriveInsert:
				_, _, err = out.Insert(wtxn, outObj)
			case DeriveUpdate:
				// Insert only when an object with the same primary key is
				// already present in the output table.
				_, _, found := out.Get(wtxn, out.PrimaryIndexer().QueryFromObject(outObj))
				if found {
					_, _, err = out.Insert(wtxn, outObj)
				}
			case DeriveDelete:
				_, _, err = out.Delete(wtxn, outObj)
			case DeriveSkip:
			}
			if err != nil {
				// Abort the whole batch on the first failed write.
				wtxn.Abort()
				return err
			}
		}
		wtxn.Commit()

		// Wait for the watch channel returned by Next before processing
		// again, or stop when the job context is cancelled.
		select {
		case <-watch:
		case <-ctx.Done():
			return nil
		}
	}
}
// derived is the object type stored in the output table of TestDerive.
type derived struct {
	ID      uint64
	Deleted bool
}

// derivedIdIndex is the unique index named "id" that keys derived objects
// by their ID field.
var derivedIdIndex = Index[derived, uint64]{
	Name: "id",
	FromObject: func(t derived) index.KeySet {
		return index.NewKeySet(index.Uint64(t.ID))
	},
	FromKey: index.Uint64,
	Unique:  true,
}

// nopHealth is a cell.Health implementation whose methods do nothing.
type nopHealth struct {
}

// Degraded implements cell.Health.
func (*nopHealth) Degraded(reason string, err error) {
}

// NewScope implements cell.Health. The receiver itself is returned as the
// new scope.
func (h *nopHealth) NewScope(name string) cell.Health {
	return h
}

// OK implements cell.Health.
func (*nopHealth) OK(status string) {
}
54 | func (*nopHealth) Stopped(reason string) { 55 | } 56 | 57 | func (*nopHealth) Close() {} 58 | 59 | func newNopHealth() (cell.Health, *nopHealth) { 60 | h := &nopHealth{} 61 | return h, h 62 | } 63 | 64 | var _ cell.Health = &nopHealth{} 65 | 66 | func TestDerive(t *testing.T) { 67 | var db *DB 68 | inTable, err := NewTable("test", idIndex) 69 | require.NoError(t, err) 70 | outTable, err := NewTable("derived", derivedIdIndex) 71 | require.NoError(t, err) 72 | 73 | transform := func(obj testObject, deleted bool) (derived, DeriveResult) { 74 | t.Logf("transform(%v, %v)", obj, deleted) 75 | 76 | tags := slices.Collect(obj.Tags.All()) 77 | if obj.Tags.Len() > 0 && tags[0] == "skip" { 78 | return derived{}, DeriveSkip 79 | } 80 | if deleted { 81 | if obj.Tags.Len() > 0 && tags[0] == "delete" { 82 | return derived{ID: obj.ID}, DeriveDelete 83 | } 84 | return derived{ID: obj.ID, Deleted: true}, DeriveUpdate 85 | } 86 | return derived{ID: obj.ID, Deleted: false}, DeriveInsert 87 | } 88 | 89 | h := hive.New( 90 | Cell, // DB 91 | job.Cell, 92 | cell.Provide(newNopHealth), 93 | cell.Module( 94 | "test", "Test", 95 | 96 | cell.Provide(func(db_ *DB) (Table[testObject], RWTable[derived], error) { 97 | db = db_ 98 | if err := db.RegisterTable(inTable); err != nil { 99 | return nil, nil, err 100 | } 101 | if err := db.RegisterTable(outTable); err != nil { 102 | return nil, nil, err 103 | } 104 | return inTable, outTable, nil 105 | }), 106 | 107 | cell.Invoke(Derive("testObject-to-derived", transform)), 108 | ), 109 | ) 110 | log := hivetest.Logger(t, hivetest.LogLevel(slog.LevelError)) 111 | require.NoError(t, h.Start(log, context.TODO()), "Start") 112 | 113 | getDerived := func() []derived { 114 | txn := db.ReadTxn() 115 | objs := Collect(outTable.All(txn)) 116 | // Log so we can trace the failed eventually calls 117 | t.Logf("derived: %+v", objs) 118 | return objs 119 | } 120 | 121 | // Insert 1, 2 and 3 (skipped) and validate. 
122 | wtxn := db.WriteTxn(inTable) 123 | _, _, err = inTable.Insert(wtxn, testObject{ID: 1}) 124 | require.NoError(t, err, "Insert failed") 125 | _, _, err = inTable.Insert(wtxn, testObject{ID: 2}) 126 | require.NoError(t, err, "Insert failed") 127 | _, _, err = inTable.Insert(wtxn, testObject{ID: 3, Tags: part.NewSet("skip")}) 128 | require.NoError(t, err, "Insert failed") 129 | wtxn.Commit() 130 | 131 | require.Eventually(t, 132 | func() bool { 133 | objs := getDerived() 134 | return len(objs) == 2 && // 3 is skipped 135 | objs[0].ID == 1 && objs[1].ID == 2 136 | }, 137 | time.Second, 138 | 10*time.Millisecond, 139 | "expected 1 & 2 to be derived", 140 | ) 141 | 142 | // Delete 2 (testing DeriveUpdate) 143 | wtxn = db.WriteTxn(inTable) 144 | _, hadOld, err := inTable.Delete(wtxn, testObject{ID: 2}) 145 | require.NoError(t, err, "Delete failed") 146 | require.True(t, hadOld, "Expected object to be deleted") 147 | wtxn.Commit() 148 | 149 | require.Eventually(t, 150 | func() bool { 151 | objs := getDerived() 152 | return len(objs) == 2 && // 3 is skipped 153 | objs[0].ID == 1 && !objs[0].Deleted && 154 | objs[1].ID == 2 && objs[1].Deleted 155 | }, 156 | time.Second, 157 | 10*time.Millisecond, 158 | "expected 1 & 2, with 2 marked deleted", 159 | ) 160 | 161 | // Delete 1 (testing DeriveDelete) 162 | wtxn = db.WriteTxn(inTable) 163 | _, _, err = inTable.Insert(wtxn, testObject{ID: 1, Tags: part.NewSet("delete")}) 164 | require.NoError(t, err, "Insert failed") 165 | wtxn.Commit() 166 | wtxn = db.WriteTxn(inTable) 167 | _, _, err = inTable.Delete(wtxn, testObject{ID: 1}) 168 | require.NoError(t, err, "Delete failed") 169 | wtxn.Commit() 170 | 171 | require.Eventually(t, 172 | func() bool { 173 | objs := getDerived() 174 | return len(objs) == 1 && 175 | objs[0].ID == 2 && objs[0].Deleted 176 | }, 177 | time.Second, 178 | 10*time.Millisecond, 179 | "expected 1 to be gone, and 2 mark deleted", 180 | ) 181 | 182 | require.NoError(t, h.Stop(log, context.TODO()), "Stop") 183 
| } 184 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | // The statedb package provides a transactional in-memory database with per-table locking. 5 | // The database indexes objects using Persistent Adaptive Radix Trees. 6 | // (https://db.in.tum.de/~leis/papers/ART.pdf) 7 | // 8 | // As this is built around an immutable data structure and objects may have lockless readers, 9 | // the stored objects MUST NOT be mutated, but instead a copy must be made prior to mutation 10 | // and insertion. 11 | // 12 | // See 'reconciler/example/' for an example of how to construct an application that uses this library. 13 | package statedb 14 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | ) 10 | 11 | var ( 12 | // ErrDuplicateTable indicates that StateDB has been provided with two or more table definitions 13 | // that share the same table name. 14 | ErrDuplicateTable = errors.New("table already exists") 15 | 16 | // ErrTableNotRegistered indicates that a user tries to write to a table that has not been 17 | // registered with this StateDB instance. 18 | ErrTableNotRegistered = errors.New("table not registered") 19 | 20 | // ErrPrimaryIndexNotUnique indicates that the primary index for the table is not marked unique. 21 | ErrPrimaryIndexNotUnique = errors.New("primary index not unique") 22 | 23 | // ErrDuplicateIndex indicates that the table has two or more indexers that share the same name. 
24 | ErrDuplicateIndex = errors.New("index name already in use") 25 | 26 | // ErrReservedPrefix indicates that the index name is using the reserved prefix and should 27 | // be renamed. 28 | ErrReservedPrefix = errors.New("index name uses reserved prefix '" + reservedIndexPrefix + "'") 29 | 30 | // ErrTransactionClosed indicates that a write operation is performed using a transaction 31 | // that has already been committed or aborted. 32 | ErrTransactionClosed = errors.New("transaction is closed") 33 | 34 | // ErrTableNotLockedForWriting indicates that a write operation is performed against a 35 | // table that was not locked for writing, e.g. target table not given as argument to 36 | // WriteTxn(). 37 | ErrTableNotLockedForWriting = errors.New("not locked for writing") 38 | 39 | // ErrRevisionNotEqual indicates that the CompareAndSwap or CompareAndDelete failed due to 40 | // the object having a mismatching revision, e.g. it had been changed since the object 41 | // was last read. 42 | ErrRevisionNotEqual = errors.New("revision not equal") 43 | 44 | // ErrObjectNotFound indicates that the object was not found when the operation required 45 | // it to exist. This error is not returned by Insert or Delete, but may be returned by 46 | // CompareAndSwap or CompareAndDelete. 47 | ErrObjectNotFound = errors.New("object not found") 48 | ) 49 | 50 | // tableError wraps an error with the table name. 
func tableError(tableName string, err error) error {
	// The table name is quoted with %q, and %w keeps err available to
	// errors.Is / errors.As on the returned error.
	const format = "table %q: %w"
	return fmt.Errorf(format, tableName, err)
}
golang.org/x/tools v0.17.0 // indirect 41 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 42 | gopkg.in/ini.v1 v1.67.0 // indirect 43 | ) 44 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cilium/hive v0.0.0-20241213101835-553aca42f74a h1:KuDVdRWFhuntkXMuXBraKvsJ4o6HuPf3iF2hETefRtE= 2 | github.com/cilium/hive v0.0.0-20241213101835-553aca42f74a/go.mod h1:pI2GJ1n3SLKIQVFrKF7W6A6gb6BQkZ+3Hp4PAEo5SuI= 3 | github.com/cilium/hive v0.0.0-20241213121623-605c1412b9b3 h1:RfmUH1ouzj0LzORYJRhp43e1rlGpx6GNv4NIRUakU2w= 4 | github.com/cilium/hive v0.0.0-20241213121623-605c1412b9b3/go.mod h1:pI2GJ1n3SLKIQVFrKF7W6A6gb6BQkZ+3Hp4PAEo5SuI= 5 | github.com/cilium/hive v0.0.0-20250409150907-8eacab6fab5b h1:00k4EwXiIZ2J3cMt0xvawCqYu/oULGIpKta3o/U8CgE= 6 | github.com/cilium/hive v0.0.0-20250409150907-8eacab6fab5b/go.mod h1:3DSFWuYTjYtWkanf84uwMgDvP1pjJ313zXJxMfkz/Eg= 7 | github.com/cilium/hive v0.0.0-20250522123230-2946c4940f41 h1:H3y9UKJpLlMhkuu+008rotTY52gGIE8U2wEUJKp+hoc= 8 | github.com/cilium/hive v0.0.0-20250522123230-2946c4940f41/go.mod h1:3DSFWuYTjYtWkanf84uwMgDvP1pjJ313zXJxMfkz/Eg= 9 | github.com/cilium/stream v0.0.0-20240209152734-a0792b51812d h1:p6MgATaKEB9o7iAsk9rlzXNDMNCeKPAkx4Y8f+Zq8X8= 10 | github.com/cilium/stream v0.0.0-20240209152734-a0792b51812d/go.mod h1:3VLiLgs8wfjirkuYqos4t0IBPQ+sXtf3tFkChLm6ARM= 11 | github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= 12 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 13 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 14 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= 15 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 16 | github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= 17 | github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 18 | github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= 19 | github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= 20 | github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= 21 | github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= 22 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 23 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 24 | github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= 25 | github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= 26 | github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= 27 | github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= 28 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 29 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 30 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 31 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 32 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 33 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 34 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 35 | github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de h1:9TO3cAIGXtEhnIaL+V+BEER86oLrvS+kWobKpbJuye0= 36 | github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de/go.mod 
h1:zAbeS9B/r2mtpb6U+EI2rYA5OAXxsYw6wTamcNW+zcE= 37 | github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= 38 | github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= 39 | github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= 40 | github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= 41 | github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= 42 | github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= 43 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 44 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= 45 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 46 | github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= 47 | github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= 48 | github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 49 | github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= 50 | github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= 51 | github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= 52 | github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= 53 | github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= 54 | github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= 55 | github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= 56 | github.com/spf13/afero v1.11.0/go.mod 
h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= 57 | github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= 58 | github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= 59 | github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= 60 | github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= 61 | github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= 62 | github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= 63 | github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= 64 | github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= 65 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 66 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 67 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 68 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 69 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 70 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 71 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 72 | github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= 73 | github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= 74 | go.uber.org/dig v1.17.1 h1:Tga8Lz8PcYNsWsyHMZ1Vm0OQOUaJNDyvPImgbAu9YSc= 75 | go.uber.org/dig v1.17.1/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= 76 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 77 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 78 | go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= 79 | 
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= 80 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= 81 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= 82 | golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= 83 | golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 84 | golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= 85 | golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= 86 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 87 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 88 | golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= 89 | golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 90 | golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= 91 | golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= 92 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 93 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 94 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 95 | gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= 96 | gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= 97 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 98 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 99 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 100 | -------------------------------------------------------------------------------- /graveyard.go: 
-------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "context" 8 | "maps" 9 | "slices" 10 | "time" 11 | 12 | "golang.org/x/time/rate" 13 | ) 14 | 15 | const ( 16 | // defaultGCRateLimitInterval is the default minimum interval between garbage collections. 17 | defaultGCRateLimitInterval = time.Second 18 | ) 19 | 20 | func graveyardWorker(db *DB, ctx context.Context, gcRateLimitInterval time.Duration) { 21 | limiter := rate.NewLimiter(rate.Every(gcRateLimitInterval), 1) 22 | defer close(db.gcExited) 23 | 24 | for { 25 | select { 26 | case <-ctx.Done(): 27 | return 28 | case <-db.gcTrigger: 29 | } 30 | 31 | // Throttle garbage collection. 32 | if err := limiter.Wait(ctx); err != nil { 33 | return 34 | } 35 | 36 | cleaningTimes := make(map[string]time.Duration) 37 | 38 | type deadObjectRevisionKey = []byte 39 | toBeDeleted := map[TableMeta][]deadObjectRevisionKey{} 40 | 41 | // Do a lockless read transaction to find potential dead objects. 42 | txn := db.ReadTxn().getTxn() 43 | for _, table := range txn.root { 44 | tableName := table.meta.Name() 45 | start := time.Now() 46 | 47 | // Find the low watermark 48 | lowWatermark := table.revision 49 | dtIter := table.deleteTrackers.Iterator() 50 | for _, dt, ok := dtIter.Next(); ok; _, dt, ok = dtIter.Next() { 51 | rev := dt.getRevision() 52 | if rev < lowWatermark { 53 | lowWatermark = rev 54 | } 55 | } 56 | 57 | db.metrics.GraveyardLowWatermark( 58 | tableName, 59 | lowWatermark, 60 | ) 61 | 62 | // Find objects to be deleted by iterating over the graveyard revision index up 63 | // to the low watermark. 
64 | indexTree := txn.mustIndexReadTxn(table.meta, GraveyardRevisionIndexPos) 65 | 66 | objIter := indexTree.Iterator() 67 | for key, obj, ok := objIter.Next(); ok; key, obj, ok = objIter.Next() { 68 | if obj.revision > lowWatermark { 69 | break 70 | } 71 | toBeDeleted[table.meta] = append(toBeDeleted[table.meta], key) 72 | } 73 | cleaningTimes[tableName] = time.Since(start) 74 | } 75 | 76 | if len(toBeDeleted) == 0 { 77 | for tableName, stat := range cleaningTimes { 78 | db.metrics.GraveyardCleaningDuration( 79 | tableName, 80 | stat, 81 | ) 82 | } 83 | continue 84 | } 85 | 86 | // Dead objects found, do a write transaction against all tables with dead objects in them. 87 | tablesToModify := slices.Collect(maps.Keys(toBeDeleted)) 88 | txn = db.WriteTxn(tablesToModify[0], tablesToModify[1:]...).getTxn() 89 | for meta, deadObjs := range toBeDeleted { 90 | tableName := meta.Name() 91 | start := time.Now() 92 | for _, key := range deadObjs { 93 | oldObj, existed := txn.mustIndexWriteTxn(meta, GraveyardRevisionIndexPos).Delete(key) 94 | if existed { 95 | // The dead object still existed (and wasn't replaced by a create->delete), 96 | // delete it from the primary index. 97 | key = meta.primary().fromObject(oldObj).First() 98 | txn.mustIndexWriteTxn(meta, GraveyardIndexPos).Delete(key) 99 | } 100 | } 101 | cleaningTimes[tableName] = time.Since(start) 102 | } 103 | txn.Commit() 104 | 105 | for tableName, stat := range cleaningTimes { 106 | db.metrics.GraveyardCleaningDuration( 107 | tableName, 108 | stat, 109 | ) 110 | } 111 | 112 | // Update object count metrics. 113 | txn = db.ReadTxn().getTxn() 114 | for _, table := range txn.root { 115 | name := table.meta.Name() 116 | db.metrics.GraveyardObjectCount(string(name), table.numDeletedObjects()) 117 | db.metrics.ObjectCount(string(name), table.numObjects()) 118 | } 119 | } 120 | } 121 | 122 | // graveyardIsEmpty returns true if no objects exist in the graveyard of any table. 123 | // Used in tests. 
124 | func (db *DB) graveyardIsEmpty() bool { 125 | txn := db.ReadTxn().getTxn() 126 | for _, table := range txn.root { 127 | indexEntry := table.indexes[table.meta.indexPos(GraveyardIndex)] 128 | if indexEntry.tree.Len() != 0 { 129 | return false 130 | } 131 | } 132 | return true 133 | } 134 | -------------------------------------------------------------------------------- /http.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "encoding/base64" 8 | "encoding/json" 9 | "fmt" 10 | "io" 11 | "net/http" 12 | "time" 13 | 14 | "github.com/cilium/statedb/part" 15 | ) 16 | 17 | func (db *DB) HTTPHandler() http.Handler { 18 | h := dbHandler{db} 19 | mux := http.NewServeMux() 20 | mux.HandleFunc("GET /dump", h.dumpAll) 21 | mux.HandleFunc("GET /dump/{table}", h.dumpTable) 22 | mux.HandleFunc("GET /query", h.query) 23 | mux.HandleFunc("GET /changes/{table}", h.changes) 24 | return mux 25 | } 26 | 27 | type dbHandler struct { 28 | db *DB 29 | } 30 | 31 | func (h dbHandler) dumpAll(w http.ResponseWriter, r *http.Request) { 32 | w.Header().Add("Content-Type", "application/json") 33 | w.WriteHeader(http.StatusOK) 34 | h.db.ReadTxn().WriteJSON(w) 35 | } 36 | 37 | func (h dbHandler) dumpTable(w http.ResponseWriter, r *http.Request) { 38 | w.Header().Add("Content-Type", "application/json") 39 | w.WriteHeader(http.StatusOK) 40 | 41 | var err error 42 | if table := r.PathValue("table"); table != "" { 43 | err = h.db.ReadTxn().WriteJSON(w, r.PathValue("table")) 44 | } else { 45 | err = h.db.ReadTxn().WriteJSON(w) 46 | } 47 | if err != nil { 48 | panic(err) 49 | } 50 | } 51 | 52 | func (h dbHandler) query(w http.ResponseWriter, r *http.Request) { 53 | enc := json.NewEncoder(w) 54 | 55 | var req QueryRequest 56 | body, err := io.ReadAll(r.Body) 57 | r.Body.Close() 58 | if err != nil { 59 | w.WriteHeader(http.StatusBadRequest) 60 | 
enc.Encode(QueryResponse{Err: err.Error()}) 61 | return 62 | } 63 | 64 | if err := json.Unmarshal(body, &req); err != nil { 65 | w.WriteHeader(http.StatusBadRequest) 66 | enc.Encode(QueryResponse{Err: err.Error()}) 67 | return 68 | } 69 | 70 | queryKey, err := base64.StdEncoding.DecodeString(req.Key) 71 | if err != nil { 72 | w.WriteHeader(http.StatusBadRequest) 73 | enc.Encode(QueryResponse{Err: err.Error()}) 74 | return 75 | } 76 | 77 | txn := h.db.ReadTxn().getTxn() 78 | 79 | // Look up the table 80 | var table TableMeta 81 | for _, e := range txn.root { 82 | if e.meta.Name() == req.Table { 83 | table = e.meta 84 | break 85 | } 86 | } 87 | if table == nil { 88 | w.WriteHeader(http.StatusNotFound) 89 | enc.Encode(QueryResponse{Err: fmt.Sprintf("Table %q not found", req.Table)}) 90 | return 91 | } 92 | 93 | indexPos := table.indexPos(req.Index) 94 | 95 | indexTxn, err := txn.indexReadTxn(table, indexPos) 96 | if err != nil { 97 | w.WriteHeader(http.StatusBadRequest) 98 | enc.Encode(QueryResponse{Err: err.Error()}) 99 | return 100 | } 101 | 102 | w.WriteHeader(http.StatusOK) 103 | onObject := func(obj object) error { 104 | return enc.Encode(QueryResponse{ 105 | Rev: obj.revision, 106 | Obj: obj.data, 107 | }) 108 | } 109 | runQuery(indexTxn, req.LowerBound, queryKey, onObject) 110 | } 111 | 112 | type QueryRequest struct { 113 | Key string `json:"key"` // Base64 encoded query key 114 | Table string `json:"table"` 115 | Index string `json:"index"` 116 | LowerBound bool `json:"lowerbound"` 117 | } 118 | 119 | type QueryResponse struct { 120 | Rev uint64 `json:"rev"` 121 | Obj any `json:"obj"` 122 | Err string `json:"err,omitempty"` 123 | } 124 | 125 | func runQuery(indexTxn indexReadTxn, lowerbound bool, queryKey []byte, onObject func(object) error) { 126 | var iter *part.Iterator[object] 127 | if !indexTxn.unique { 128 | queryKey = encodeNonUniqueBytes(queryKey) 129 | } 130 | if lowerbound { 131 | iter = indexTxn.LowerBound(queryKey) 132 | } else { 133 | iter, _ = 
indexTxn.Prefix(queryKey) 134 | } 135 | var match func([]byte) bool 136 | switch { 137 | case lowerbound: 138 | match = func([]byte) bool { return true } 139 | case indexTxn.unique: 140 | match = func(k []byte) bool { return len(k) == len(queryKey) } 141 | default: 142 | match = func(k []byte) bool { 143 | return nonUniqueKey(k).secondaryLen() == len(queryKey) 144 | } 145 | } 146 | for key, obj, ok := iter.Next(); ok; key, obj, ok = iter.Next() { 147 | if !match(key) { 148 | continue 149 | } 150 | if err := onObject(obj); err != nil { 151 | panic(err) 152 | } 153 | } 154 | } 155 | 156 | func (h dbHandler) changes(w http.ResponseWriter, r *http.Request) { 157 | const keepaliveInterval = 30 * time.Second 158 | 159 | enc := json.NewEncoder(w) 160 | tableName := r.PathValue("table") 161 | 162 | // Look up the table 163 | var tableMeta TableMeta 164 | for _, e := range h.db.ReadTxn().getTxn().root { 165 | if e.meta.Name() == tableName { 166 | tableMeta = e.meta 167 | break 168 | } 169 | } 170 | if tableMeta == nil { 171 | w.WriteHeader(http.StatusNotFound) 172 | enc.Encode(QueryResponse{Err: fmt.Sprintf("Table %q not found", tableName)}) 173 | return 174 | } 175 | 176 | // Register for changes. 
177 | wtxn := h.db.WriteTxn(tableMeta) 178 | changeIter, err := tableMeta.anyChanges(wtxn) 179 | wtxn.Commit() 180 | if err != nil { 181 | w.WriteHeader(http.StatusInternalServerError) 182 | return 183 | } 184 | 185 | w.WriteHeader(http.StatusOK) 186 | 187 | ticker := time.NewTicker(keepaliveInterval) 188 | defer ticker.Stop() 189 | 190 | for { 191 | changes, watch := changeIter.nextAny(h.db.ReadTxn()) 192 | for change := range changes { 193 | err := enc.Encode(change) 194 | if err != nil { 195 | panic(err) 196 | } 197 | } 198 | w.(http.Flusher).Flush() 199 | select { 200 | case <-r.Context().Done(): 201 | return 202 | 203 | case <-ticker.C: 204 | // Send an empty keep-alive 205 | enc.Encode(Change[any]{}) 206 | 207 | case <-watch: 208 | } 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /http_client.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "encoding/base64" 10 | "encoding/json" 11 | "errors" 12 | "fmt" 13 | "io" 14 | "iter" 15 | "net/http" 16 | "net/url" 17 | ) 18 | 19 | // NewRemoteTable creates a new handle for querying a remote StateDB table over the HTTP. 20 | // Example usage: 21 | // 22 | // devices := statedb.NewRemoteTable[*tables.Device](url.Parse("http://localhost:8080/db"), "devices") 23 | // 24 | // // Get all devices ordered by name. 25 | // iter, errs := devices.LowerBound(ctx, tables.DeviceByName("")) 26 | // for device, revision, ok := iter.Next(); ok; device, revision, ok = iter.Next() { ... } 27 | // 28 | // // Get device by name. 29 | // iter, errs := devices.Get(ctx, tables.DeviceByName("eth0")) 30 | // if dev, revision, ok := iter.Next(); ok { ... } 31 | // 32 | // // Get devices in revision order, e.g. oldest changed devices first. 
33 | // iter, errs = devices.LowerBound(ctx, statedb.ByRevision(0)) 34 | func NewRemoteTable[Obj any](base *url.URL, table TableName) *RemoteTable[Obj] { 35 | return &RemoteTable[Obj]{base: base, tableName: table} 36 | } 37 | 38 | type RemoteTable[Obj any] struct { 39 | client http.Client 40 | base *url.URL 41 | tableName TableName 42 | } 43 | 44 | func (t *RemoteTable[Obj]) SetTransport(tr *http.Transport) { 45 | t.client.Transport = tr 46 | } 47 | 48 | func (t *RemoteTable[Obj]) query(ctx context.Context, lowerBound bool, q Query[Obj]) (seq iter.Seq2[Obj, Revision], errChan <-chan error) { 49 | // Use a channel to return errors so we can use the same Iterator[Obj] interface as StateDB does. 50 | errChanSend := make(chan error, 1) 51 | errChan = errChanSend 52 | 53 | key := base64.StdEncoding.EncodeToString(q.key) 54 | queryReq := QueryRequest{ 55 | Key: key, 56 | Table: t.tableName, 57 | Index: q.index, 58 | LowerBound: lowerBound, 59 | } 60 | bs, err := json.Marshal(&queryReq) 61 | if err != nil { 62 | errChanSend <- err 63 | return 64 | } 65 | 66 | url := t.base.JoinPath("/query") 67 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), bytes.NewBuffer(bs)) 68 | if err != nil { 69 | errChanSend <- err 70 | return 71 | } 72 | req.Header.Add("Content-Type", "application/json") 73 | req.Header.Add("Accept", "application/json") 74 | 75 | resp, err := t.client.Do(req) 76 | if err != nil { 77 | errChanSend <- err 78 | return 79 | } 80 | return remoteGetSeq[Obj](json.NewDecoder(resp.Body), errChanSend), errChan 81 | } 82 | 83 | func (t *RemoteTable[Obj]) Get(ctx context.Context, q Query[Obj]) (iter.Seq2[Obj, Revision], <-chan error) { 84 | return t.query(ctx, false, q) 85 | } 86 | 87 | func (t *RemoteTable[Obj]) LowerBound(ctx context.Context, q Query[Obj]) (iter.Seq2[Obj, Revision], <-chan error) { 88 | return t.query(ctx, true, q) 89 | } 90 | 91 | // responseObject is a typed counterpart of [queryResponseObject] 92 | type responseObject[Obj 
any] struct { 93 | Rev uint64 `json:"rev"` 94 | Obj Obj `json:"obj"` 95 | Err string `json:"err,omitempty"` 96 | } 97 | 98 | func remoteGetSeq[Obj any](dec *json.Decoder, errChan chan error) iter.Seq2[Obj, Revision] { 99 | return func(yield func(Obj, Revision) bool) { 100 | for { 101 | var resp responseObject[Obj] 102 | err := dec.Decode(&resp) 103 | errString := "" 104 | if err != nil { 105 | if errors.Is(err, io.EOF) { 106 | close(errChan) 107 | break 108 | } 109 | errString = "Decode error: " + err.Error() 110 | } else { 111 | errString = resp.Err 112 | } 113 | if errString != "" { 114 | errChan <- errors.New(errString) 115 | break 116 | } 117 | if !yield(resp.Obj, resp.Rev) { 118 | break 119 | } 120 | } 121 | } 122 | } 123 | 124 | func (t *RemoteTable[Obj]) Changes(ctx context.Context) (seq iter.Seq2[Change[Obj], Revision], errChan <-chan error) { 125 | // Use a channel to return errors so we can use the same Iterator[Obj] interface as StateDB does. 126 | errChanSend := make(chan error, 1) 127 | errChan = errChanSend 128 | 129 | url := t.base.JoinPath("/changes", t.tableName) 130 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) 131 | if err != nil { 132 | errChanSend <- err 133 | close(errChanSend) 134 | return 135 | } 136 | 137 | req.Header.Add("Content-Type", "application/json") 138 | req.Header.Add("Accept", "application/json") 139 | 140 | resp, err := t.client.Do(req) 141 | if err != nil { 142 | errChanSend <- err 143 | close(errChanSend) 144 | return 145 | } 146 | return remoteChangeSeq[Obj](json.NewDecoder(resp.Body), errChanSend), errChan 147 | } 148 | 149 | func remoteChangeSeq[Obj any](dec *json.Decoder, errChan chan error) iter.Seq2[Change[Obj], Revision] { 150 | return func(yield func(Change[Obj], Revision) bool) { 151 | defer close(errChan) 152 | for { 153 | var change Change[Obj] 154 | err := dec.Decode(&change) 155 | if err == nil && change.Revision == 0 { 156 | // Keep-alive message, skip it. 
157 | continue 158 | } 159 | 160 | if err != nil { 161 | if !errors.Is(err, io.EOF) { 162 | errChan <- fmt.Errorf("decode error: %w", err) 163 | } 164 | return 165 | } 166 | 167 | if !yield(change, change.Revision) { 168 | return 169 | } 170 | } 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /http_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "context" 8 | "encoding/json" 9 | "io" 10 | "net/http" 11 | "net/http/httptest" 12 | "net/url" 13 | "testing" 14 | 15 | "github.com/stretchr/testify/assert" 16 | "github.com/stretchr/testify/require" 17 | 18 | "github.com/cilium/statedb/index" 19 | "github.com/cilium/statedb/part" 20 | ) 21 | 22 | func httpFixture(t *testing.T) (*DB, RWTable[testObject], *httptest.Server) { 23 | db, table, _ := newTestDB(t, tagsIndex) 24 | 25 | ts := httptest.NewServer(db.HTTPHandler()) 26 | t.Cleanup(ts.Close) 27 | 28 | wtxn := db.WriteTxn(table) 29 | table.Insert(wtxn, testObject{1, part.NewSet("foo")}) 30 | table.Insert(wtxn, testObject{2, part.NewSet("foo")}) 31 | table.Insert(wtxn, testObject{3, part.NewSet("foobar")}) 32 | table.Insert(wtxn, testObject{4, part.NewSet("baz")}) 33 | wtxn.Commit() 34 | 35 | return db, table, ts 36 | } 37 | 38 | func Test_http_dump(t *testing.T) { 39 | db, tbl, ts := httpFixture(t) 40 | 41 | resp, err := http.Get(ts.URL + "/dump") 42 | require.NoError(t, err, "Get(/dump)") 43 | require.Equal(t, http.StatusOK, resp.StatusCode) 44 | 45 | dump, err := io.ReadAll(resp.Body) 46 | resp.Body.Close() 47 | require.NoError(t, err, "ReadAll") 48 | 49 | var data map[string]any 50 | require.NoError(t, json.Unmarshal(dump, &data), "Unmarshal") 51 | test, ok := data["test"] 52 | require.True(t, ok) 53 | require.Len(t, test, tbl.NumObjects(db.ReadTxn())) 54 | 55 | resp, err = http.Get(ts.URL + 
"/dump/test") 56 | require.NoError(t, err, "Get(/dump/test)") 57 | require.Equal(t, http.StatusOK, resp.StatusCode) 58 | 59 | dump, err = io.ReadAll(resp.Body) 60 | resp.Body.Close() 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | require.NoError(t, json.Unmarshal(dump, &data), "Unmarshal") 66 | test, ok = data["test"] 67 | require.True(t, ok) 68 | require.Len(t, test, tbl.NumObjects(db.ReadTxn())) 69 | } 70 | 71 | func Test_http_runQuery(t *testing.T) { 72 | db, table, _ := httpFixture(t) 73 | txn := db.ReadTxn() 74 | 75 | // idIndex, unique 76 | indexTxn, err := txn.getTxn().indexReadTxn(table, table.indexPos(idIndex.Name)) 77 | require.NoError(t, err) 78 | items := []object{} 79 | onObject := func(obj object) error { 80 | items = append(items, obj) 81 | return nil 82 | } 83 | runQuery(indexTxn, false, index.Uint64(1), onObject) 84 | if assert.Len(t, items, 1) { 85 | assert.EqualValues(t, items[0].data.(testObject).ID, 1) 86 | } 87 | 88 | // tagsIndex, non-unique 89 | indexTxn, err = txn.getTxn().indexReadTxn(table, table.indexPos(tagsIndex.Name)) 90 | require.NoError(t, err) 91 | items = nil 92 | runQuery(indexTxn, false, index.String("foo"), onObject) 93 | 94 | if assert.Len(t, items, 2) { 95 | assert.EqualValues(t, items[0].data.(testObject).ID, 1) 96 | assert.EqualValues(t, items[1].data.(testObject).ID, 2) 97 | } 98 | 99 | // lower-bound on revision index 100 | indexTxn, err = txn.getTxn().indexReadTxn(table, RevisionIndexPos) 101 | require.NoError(t, err) 102 | items = nil 103 | runQuery(indexTxn, true, index.Uint64(0), onObject) 104 | if assert.Len(t, items, 4) { 105 | // Items are in revision (creation) order 106 | assert.EqualValues(t, items[0].data.(testObject).ID, 1) 107 | assert.EqualValues(t, items[1].data.(testObject).ID, 2) 108 | assert.EqualValues(t, items[2].data.(testObject).ID, 3) 109 | assert.EqualValues(t, items[3].data.(testObject).ID, 4) 110 | } 111 | } 112 | 113 | func Test_http_RemoteTable_Get_LowerBound(t *testing.T) { 114 | ctx 
:= context.TODO() 115 | _, table, ts := httpFixture(t) 116 | 117 | base, err := url.Parse(ts.URL) 118 | require.NoError(t, err, "ParseURL") 119 | 120 | remoteTable := NewRemoteTable[testObject](base, table.Name()) 121 | 122 | iter, errs := remoteTable.Get(ctx, idIndex.Query(1)) 123 | items := Collect(iter) 124 | assert.NoError(t, <-errs, "Get(1)") 125 | if assert.Len(t, items, 1) { 126 | assert.EqualValues(t, 1, items[0].ID) 127 | } 128 | 129 | iter, errs = remoteTable.LowerBound(ctx, idIndex.Query(0)) 130 | items = Collect(iter) 131 | assert.NoError(t, <-errs, "LowerBound(0)") 132 | if assert.Len(t, items, 4) { 133 | assert.EqualValues(t, 1, items[0].ID) 134 | assert.EqualValues(t, 2, items[1].ID) 135 | assert.EqualValues(t, 3, items[2].ID) 136 | assert.EqualValues(t, 4, items[3].ID) 137 | } 138 | } 139 | 140 | func Test_http_RemoteTable_Changes(t *testing.T) { 141 | ctx, cancel := context.WithCancel(context.TODO()) 142 | db, table, ts := httpFixture(t) 143 | 144 | base, err := url.Parse(ts.URL) 145 | require.NoError(t, err, "ParseURL") 146 | 147 | remoteTable := NewRemoteTable[testObject](base, table.Name()) 148 | 149 | iter, errs := remoteTable.LowerBound(ctx, idIndex.Query(0)) 150 | items := Collect(iter) 151 | require.NoError(t, <-errs, "LowerBound(0)") 152 | require.Len(t, items, 4) 153 | 154 | changes, errs := remoteTable.Changes(ctx) 155 | // Consume the changes via a channel so it is easier to assert. 
var (
	trueKey  = []byte{'T'}
	falseKey = []byte{'F'}
)

// Bool returns the key for a boolean: "T" for true and "F" for false.
// The returned slice is shared between calls and must not be modified.
func Bool(b bool) Key {
	if b {
		return trueKey
	}
	return falseKey
}

// BoolString parses s with strconv.ParseBool and returns the key for the
// parsed value. On a parse error an empty key is returned with the error.
func BoolString(s string) (Key, error) {
	b, err := strconv.ParseBool(s)
	if err != nil {
		return Key{}, err
	}
	return Bool(b), nil
}

// The indexing functions on integers should use big-endian encoding.
// This allows prefix searching on integers as the most significant
// byte is first.
// For example to find 16-bit key larger than 260 (0x0104) from 3 (0x0003)
// and 270 (0x010E)
//   00 (3) < 01 (260) => skip,
//   01 (270) >= 01 (260) => 0E > 04 => found!
//
// NOTE: signed integers are converted to the unsigned type of the same width
// before encoding, so the keys of negative numbers compare greater than the
// keys of non-negative numbers.

// Int returns the 4-byte big-endian key for n.
// NOTE: n is converted to int32, so values outside the int32 range are truncated.
func Int(n int) Key {
	return Int32(int32(n))
}

// IntString parses s as a base-10 32-bit signed integer and returns its key.
func IntString(s string) (Key, error) {
	return Int32String(s)
}

// Int64 returns the 8-byte big-endian key for n.
func Int64(n int64) Key {
	return Uint64(uint64(n))
}

// Int64String parses s as a base-10 64-bit signed integer and returns its key.
func Int64String(s string) (Key, error) {
	n, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return Key{}, err
	}
	return Uint64(uint64(n)), nil
}

// Int32 returns the 4-byte big-endian key for n.
func Int32(n int32) Key {
	return Uint32(uint32(n))
}

// Int32String parses s as a base-10 32-bit signed integer and returns its key.
func Int32String(s string) (Key, error) {
	n, err := strconv.ParseInt(s, 10, 32)
	if err != nil {
		return Key{}, err
	}
	return Uint32(uint32(n)), nil
}

// Int16 returns the 2-byte big-endian key for n.
func Int16(n int16) Key {
	return Uint16(uint16(n))
}

// Int16String parses s as a base-10 16-bit signed integer and returns its key.
func Int16String(s string) (Key, error) {
	n, err := strconv.ParseInt(s, 10, 16)
	if err != nil {
		return Key{}, err
	}
	return Uint16(uint16(n)), nil
}

// Uint64 returns the 8-byte big-endian key for n.
func Uint64(n uint64) Key {
	return binary.BigEndian.AppendUint64(nil, n)
}

// Uint64String parses s as a base-10 64-bit unsigned integer and returns its key.
func Uint64String(s string) (Key, error) {
	n, err := strconv.ParseUint(s, 10, 64)
	if err != nil {
		return Key{}, err
	}
	return Uint64(n), nil
}

// Uint32 returns the 4-byte big-endian key for n.
func Uint32(n uint32) Key {
	return binary.BigEndian.AppendUint32(nil, n)
}

// Uint32String parses s as a base-10 32-bit unsigned integer and returns its key.
func Uint32String(s string) (Key, error) {
	n, err := strconv.ParseUint(s, 10, 32)
	if err != nil {
		return Key{}, err
	}
	return Uint32(uint32(n)), nil
}

// Uint16 returns the 2-byte big-endian key for n.
func Uint16(n uint16) Key {
	return binary.BigEndian.AppendUint16(nil, n)
}

// Uint16String parses s as a base-10 16-bit unsigned integer and returns its key.
func Uint16String(s string) (Key, error) {
	n, err := strconv.ParseUint(s, 10, 16)
	if err != nil {
		return Key{}, err
	}
	return Uint16(uint16(n)), nil
}

// Key is a byte slice describing a key used in an index by statedb.
type Key []byte

// Equal reports whether k and k2 hold the same bytes.
// A nil key and an empty key are equal.
func (k Key) Equal(k2 Key) bool {
	return bytes.Equal(k, k2)
}

// KeySet is a set of keys stored as a first key plus the remaining keys.
// A KeySet with a nil first key is treated as empty by Foreach.
type KeySet struct {
	head Key
	tail []Key
}

// First returns the first key of the set, or nil if the set is empty.
func (ks KeySet) First() Key {
	return ks.head
}

// Foreach calls fn for every key in the set, starting from the first key.
// Nothing is called when the first key is nil.
func (ks KeySet) Foreach(fn func(Key)) {
	if ks.head == nil {
		return
	}
	fn(ks.head)
	for _, k := range ks.tail {
		fn(k)
	}
}

// Exists reports whether k is one of the keys in the set.
func (ks KeySet) Exists(k Key) bool {
	// A nil head marks a missing first key. Without this check an empty set
	// would claim to contain the empty key, as nil and empty keys are equal.
	if ks.head != nil && ks.head.Equal(k) {
		return true
	}
	for _, k2 := range ks.tail {
		if k2.Equal(k) {
			return true
		}
	}
	return false
}

// NewKeySet creates a KeySet from the given keys. The keys are not copied.
func NewKeySet(keys ...Key) KeySet {
	if len(keys) == 0 {
		return KeySet{}
	}
	return KeySet{keys[0], keys[1:]}
}
ks.Exists([]byte("baz"))) 18 | require.False(t, ks.Exists([]byte("foo"))) 19 | vs := []index.Key{} 20 | ks.Foreach(func(bs index.Key) { 21 | vs = append(vs, bs) 22 | }) 23 | require.ElementsMatch(t, vs, []index.Key{index.Key("baz")}) 24 | } 25 | 26 | func TestKeySet_Multi(t *testing.T) { 27 | ks := index.NewKeySet([]byte("baz"), []byte("quux")) 28 | require.EqualValues(t, "baz", ks.First()) 29 | require.True(t, ks.Exists([]byte("baz"))) 30 | require.True(t, ks.Exists([]byte("quux"))) 31 | require.False(t, ks.Exists([]byte("foo"))) 32 | vs := [][]byte{} 33 | ks.Foreach(func(bs index.Key) { 34 | vs = append(vs, bs) 35 | }) 36 | require.ElementsMatch(t, vs, [][]byte{[]byte("baz"), []byte("quux")}) 37 | } 38 | -------------------------------------------------------------------------------- /index/map.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package index 5 | 6 | func StringMap[V any](m map[string]V) KeySet { 7 | keys := make([]Key, 0, len(m)) 8 | for k := range m { 9 | keys = append(keys, String(k)) 10 | } 11 | return NewKeySet(keys...) 12 | } 13 | -------------------------------------------------------------------------------- /index/netip.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package index 5 | 6 | import ( 7 | "bytes" 8 | "net" 9 | "net/netip" 10 | ) 11 | 12 | func NetIP(ip net.IP) Key { 13 | // Use the 16-byte form to have a constant-size key. 14 | return bytes.Clone(ip.To16()) 15 | } 16 | 17 | func NetIPAddr(addr netip.Addr) Key { 18 | // Use the 16-byte form to have a constant-size key. 
19 | buf := addr.As16() 20 | return buf[:] 21 | } 22 | 23 | func NetIPAddrString(s string) (Key, error) { 24 | addr, err := netip.ParseAddr(s) 25 | if err != nil { 26 | return Key{}, err 27 | } 28 | return NetIPAddr(addr), nil 29 | } 30 | 31 | func NetIPPrefix(prefix netip.Prefix) Key { 32 | // Use the 16-byte form plus bits to have a constant-size key. 33 | addrBytes := prefix.Addr().As16() 34 | return append(addrBytes[:], uint8(prefix.Bits())) 35 | } 36 | 37 | func NetIPPrefixString(s string) (Key, error) { 38 | prefix, err := netip.ParsePrefix(s) 39 | if err != nil { 40 | return Key{}, err 41 | } 42 | return NetIPPrefix(prefix), nil 43 | } 44 | -------------------------------------------------------------------------------- /index/seq.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package index 5 | 6 | import "iter" 7 | 8 | // Seq creates a KeySet from an iter.Seq[T] with the given indexing function. 9 | // Example usage: 10 | // 11 | // var strings iter.Seq[string] 12 | // keys := Seq[string](index.String, strings) 13 | func Seq[T any]( 14 | toKey func(T) Key, 15 | seq iter.Seq[T], 16 | ) KeySet { 17 | keys := []Key{} 18 | for v := range seq { 19 | keys = append(keys, toKey(v)) 20 | } 21 | return NewKeySet(keys...) 22 | } 23 | 24 | // Seq2 creates a KeySet from an iter.Seq2[A,B] with the given indexing function. 25 | // Example usage: 26 | // 27 | // var seq iter.Seq2[string, int] 28 | // keys := Seq2(index.String, seq) 29 | func Seq2[A, B any]( 30 | toKey func(A) Key, 31 | seq iter.Seq2[A, B], 32 | ) KeySet { 33 | keys := []Key{} 34 | for a := range seq { 35 | keys = append(keys, toKey(a)) 36 | } 37 | return NewKeySet(keys...) 
38 | } 39 | -------------------------------------------------------------------------------- /index/seq_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package index_test 5 | 6 | import ( 7 | "maps" 8 | "slices" 9 | "testing" 10 | 11 | "github.com/cilium/statedb/index" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestSeq(t *testing.T) { 16 | tests := [][]uint64{ 17 | {}, 18 | {1}, 19 | {1, 2, 3}, 20 | } 21 | for _, keys := range tests { 22 | expected := []index.Key{} 23 | for _, x := range keys { 24 | expected = append(expected, index.Uint64(x)) 25 | } 26 | keySet := index.Seq(index.Uint64, slices.Values(keys)) 27 | actual := []index.Key{} 28 | keySet.Foreach(func(k index.Key) { 29 | actual = append(actual, k) 30 | }) 31 | assert.ElementsMatch(t, expected, actual) 32 | } 33 | } 34 | 35 | func TestSeq2(t *testing.T) { 36 | tests := []map[uint64]int{ 37 | nil, 38 | map[uint64]int{}, 39 | map[uint64]int{1: 1}, 40 | map[uint64]int{1: 1, 2: 2, 3: 3}, 41 | } 42 | for _, m := range tests { 43 | expected := []index.Key{} 44 | for x := range m { 45 | expected = append(expected, index.Uint64(x)) 46 | } 47 | keySet := index.Seq2(index.Uint64, maps.All(m)) 48 | actual := []index.Key{} 49 | keySet.Foreach(func(k index.Key) { 50 | actual = append(actual, k) 51 | }) 52 | assert.ElementsMatch(t, expected, actual) 53 | } 54 | } 55 | 56 | type testObj struct { 57 | x string 58 | } 59 | 60 | func (t testObj) String() string { 61 | return t.x 62 | } 63 | 64 | func TestStringerSeq(t *testing.T) { 65 | tests := [][]testObj{ 66 | {}, 67 | {testObj{"foo"}}, 68 | {testObj{"foo"}, testObj{"bar"}}, 69 | } 70 | for _, objs := range tests { 71 | expected := []index.Key{} 72 | for _, o := range objs { 73 | expected = append(expected, index.String(o.x)) 74 | } 75 | keySet := index.StringerSeq(slices.Values(objs)) 76 | actual := []index.Key{} 77 | 
keySet.Foreach(func(k index.Key) { 78 | actual = append(actual, k) 79 | }) 80 | assert.ElementsMatch(t, expected, actual) 81 | } 82 | } 83 | 84 | func TestStringerSeq2(t *testing.T) { 85 | tests := []map[testObj]int{ 86 | {}, 87 | {testObj{"foo"}: 1}, 88 | {testObj{"foo"}: 1, testObj{"bar"}: 2}, 89 | } 90 | for _, m := range tests { 91 | expected := []index.Key{} 92 | for o := range m { 93 | expected = append(expected, index.String(o.x)) 94 | } 95 | keySet := index.StringerSeq2(maps.All(m)) 96 | actual := []index.Key{} 97 | keySet.Foreach(func(k index.Key) { 98 | actual = append(actual, k) 99 | }) 100 | assert.ElementsMatch(t, expected, actual) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /index/set.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package index 5 | 6 | import "github.com/cilium/statedb/part" 7 | 8 | // Set creates a KeySet from a part.Set. 9 | func Set[T any](s part.Set[T]) KeySet { 10 | toBytes := s.ToBytesFunc() 11 | switch s.Len() { 12 | case 0: 13 | return NewKeySet() 14 | case 1: 15 | for v := range s.All() { 16 | return NewKeySet(toBytes(v)) 17 | } 18 | panic("BUG: Set.Len() == 1, but ranging returned nothing") 19 | default: 20 | keys := make([]Key, 0, s.Len()) 21 | for v := range s.All() { 22 | keys = append(keys, toBytes(v)) 23 | } 24 | return NewKeySet(keys...) 
25 | } 26 | } 27 | -------------------------------------------------------------------------------- /index/string.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package index 5 | 6 | import ( 7 | "fmt" 8 | "iter" 9 | ) 10 | 11 | func String(s string) Key { 12 | return []byte(s) 13 | } 14 | 15 | func FromString(s string) (Key, error) { 16 | return String(s), nil 17 | } 18 | 19 | func Stringer[T fmt.Stringer](s T) Key { 20 | return String(s.String()) 21 | } 22 | 23 | func StringSlice(ss []string) KeySet { 24 | keys := make([]Key, 0, len(ss)) 25 | for _, s := range ss { 26 | keys = append(keys, String(s)) 27 | } 28 | return NewKeySet(keys...) 29 | } 30 | 31 | func StringerSlice[T fmt.Stringer](ss []T) KeySet { 32 | keys := make([]Key, 0, len(ss)) 33 | for _, s := range ss { 34 | keys = append(keys, Stringer(s)) 35 | } 36 | return NewKeySet(keys...) 37 | } 38 | 39 | func StringerSeq[T fmt.Stringer](seq iter.Seq[T]) KeySet { 40 | return Seq[T](Stringer, seq) 41 | } 42 | 43 | func StringerSeq2[A fmt.Stringer, B any](seq iter.Seq2[A, B]) KeySet { 44 | return Seq2[A, B](Stringer, seq) 45 | } 46 | -------------------------------------------------------------------------------- /internal/sortable_mutex.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package internal 5 | 6 | import ( 7 | "sort" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | ) 12 | 13 | // sortableMutexSeq is a global sequence counter for the creation of new 14 | // SortableMutex's with unique sequence numbers. 15 | var sortableMutexSeq atomic.Uint64 16 | 17 | // sortableMutex implements SortableMutex. Not exported as the only way to 18 | // initialize it is via NewSortableMutex(). 
type sortableMutex struct {
	sync.Mutex

	// seq is the position of this mutex in the global lock order.
	seq uint64

	// acquireDuration is how long the most recent Lock() call waited.
	acquireDuration time.Duration
}

// Lock acquires the embedded mutex and records how long the acquisition took.
func (s *sortableMutex) Lock() {
	began := time.Now()
	s.Mutex.Lock()
	s.acquireDuration = time.Since(began)
}

// Seq returns the sequence number used to order this mutex.
func (s *sortableMutex) Seq() uint64 {
	return s.seq
}

// AcquireDuration returns the wait time recorded by the most recent Lock().
func (s *sortableMutex) AcquireDuration() time.Duration {
	return s.acquireDuration
}

// SortableMutex provides a Mutex that can be globally sorted with other
// sortable mutexes. This allows deadlock-safe locking of a set of mutexes
// as it guarantees consistent lock ordering.
type SortableMutex interface {
	sync.Locker

	// Seq is the sequence number that defines the lock order.
	Seq() uint64

	// AcquireDuration is the amount of time it took to acquire the lock.
	AcquireDuration() time.Duration
}

// SortableMutexes is a set of mutexes that can be locked in a safe order.
// Once Lock() is called it must not be mutated!
type SortableMutexes []SortableMutex

// Len implements sort.Interface.
func (s SortableMutexes) Len() int { return len(s) }

// Less implements sort.Interface by ordering on the sequence number.
func (s SortableMutexes) Less(i, j int) bool {
	left, right := s[i].Seq(), s[j].Seq()
	return left < right
}

// Swap implements sort.Interface.
func (s SortableMutexes) Swap(i, j int) {
	s[j], s[i] = s[i], s[j]
}

// Lock sorts the mutexes in place by sequence number, and then locks them in
// that order. If any lock cannot be acquired, this will block while holding
// the locks with a lower sequence number.
func (s SortableMutexes) Lock() {
	sort.Sort(s)
	for i := range s {
		s[i].Lock()
	}
}

// Unlock unlocks the sorted set of mutexes locked by a prior call to Lock().
73 | func (s SortableMutexes) Unlock() { 74 | for _, mu := range s { 75 | mu.Unlock() 76 | } 77 | } 78 | 79 | var _ sort.Interface = SortableMutexes{} 80 | 81 | func NewSortableMutex() SortableMutex { 82 | seq := sortableMutexSeq.Add(1) 83 | return &sortableMutex{ 84 | seq: seq, 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /internal/sortable_mutex_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package internal 5 | 6 | import ( 7 | "math/rand" 8 | "slices" 9 | "sync" 10 | "testing" 11 | "time" 12 | 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestSortableMutex(t *testing.T) { 17 | smu1 := NewSortableMutex() 18 | smu2 := NewSortableMutex() 19 | require.Greater(t, smu2.Seq(), smu1.Seq()) 20 | smu1.Lock() 21 | smu2.Lock() 22 | smu1.Unlock() 23 | smu2.Unlock() 24 | smus := SortableMutexes{smu1, smu2} 25 | smus.Lock() 26 | smus.Unlock() 27 | smus.Lock() 28 | smus.Unlock() 29 | } 30 | 31 | func TestSortableMutex_Chaos(t *testing.T) { 32 | smus := SortableMutexes{ 33 | NewSortableMutex(), 34 | NewSortableMutex(), 35 | NewSortableMutex(), 36 | NewSortableMutex(), 37 | NewSortableMutex(), 38 | } 39 | 40 | nMonkeys := 10 41 | iterations := 100 42 | var wg sync.WaitGroup 43 | wg.Add(nMonkeys) 44 | 45 | monkey := func() { 46 | defer wg.Done() 47 | for i := 0; i < iterations; i++ { 48 | // Take a random subset of the sortable mutexes. 
49 | subSmus := slices.Clone(smus) 50 | rand.Shuffle(len(subSmus), func(i, j int) { 51 | subSmus[i], subSmus[j] = subSmus[j], subSmus[i] 52 | }) 53 | n := rand.Intn(len(subSmus)) 54 | subSmus = subSmus[:n] 55 | 56 | time.Sleep(time.Microsecond) 57 | subSmus.Lock() 58 | time.Sleep(time.Microsecond) 59 | subSmus.Unlock() 60 | time.Sleep(time.Microsecond) 61 | } 62 | } 63 | 64 | for i := 0; i < nMonkeys; i++ { 65 | go monkey() 66 | } 67 | 68 | wg.Wait() 69 | } 70 | -------------------------------------------------------------------------------- /internal/time.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package internal 5 | 6 | import ( 7 | "fmt" 8 | "time" 9 | ) 10 | 11 | func PrettySince(t time.Time) string { 12 | return PrettyDuration(time.Since(t)) 13 | } 14 | 15 | func PrettyDuration(d time.Duration) string { 16 | ago := float64(d) / float64(time.Microsecond) 17 | 18 | // micros 19 | if ago < 1000.0 { 20 | return fmt.Sprintf("%.1fus", ago) 21 | } 22 | 23 | // millis 24 | ago /= 1000.0 25 | if ago < 1000.0 { 26 | return fmt.Sprintf("%.1fms", ago) 27 | } 28 | // secs 29 | ago /= 1000.0 30 | if ago < 60.0 { 31 | return fmt.Sprintf("%.1fs", ago) 32 | } 33 | // mins 34 | ago /= 60.0 35 | if ago < 60.0 { 36 | return fmt.Sprintf("%.1fm", ago) 37 | } 38 | // hours 39 | ago /= 60.0 40 | return fmt.Sprintf("%.1fh", ago) 41 | } 42 | -------------------------------------------------------------------------------- /iterator_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/cilium/statedb/index" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestCollectFilterMapToSeq(t *testing.T) { 15 | type testObject 
struct { 16 | ID int 17 | } 18 | 19 | db := New() 20 | idIndex := Index[*testObject, int]{ 21 | Name: "id", 22 | FromObject: func(t *testObject) index.KeySet { 23 | return index.NewKeySet(index.Int(t.ID)) 24 | }, 25 | FromKey: index.Int, 26 | Unique: true, 27 | } 28 | table, _ := NewTable("test", idIndex) 29 | require.NoError(t, db.RegisterTable(table)) 30 | db.Start() 31 | defer db.Stop() 32 | 33 | txn := db.WriteTxn(table) 34 | table.Insert(txn, &testObject{ID: 1}) 35 | table.Insert(txn, &testObject{ID: 2}) 36 | table.Insert(txn, &testObject{ID: 3}) 37 | table.Insert(txn, &testObject{ID: 4}) 38 | table.Insert(txn, &testObject{ID: 5}) 39 | txn.Commit() 40 | 41 | iter := table.All(db.ReadTxn()) 42 | filtered := Collect( 43 | Map( 44 | Filter( 45 | iter, 46 | func(obj *testObject) bool { 47 | return obj.ID%2 == 0 48 | }, 49 | ), 50 | func(obj *testObject) int { 51 | return obj.ID 52 | }, 53 | ), 54 | ) 55 | assert.Len(t, filtered, 2) 56 | assert.Equal(t, []int{2, 4}, filtered) 57 | 58 | count := 0 59 | for obj := range ToSeq(iter) { 60 | assert.Greater(t, obj.ID, 0) 61 | count++ 62 | } 63 | assert.Equal(t, 5, count) 64 | 65 | } 66 | -------------------------------------------------------------------------------- /metrics.go: -------------------------------------------------------------------------------- 1 | package statedb 2 | 3 | import ( 4 | "expvar" 5 | "fmt" 6 | "strings" 7 | "time" 8 | ) 9 | 10 | type Metrics interface { 11 | WriteTxnTableAcquisition(handle string, tableName string, acquire time.Duration) 12 | WriteTxnTotalAcquisition(handle string, tables []string, acquire time.Duration) 13 | WriteTxnDuration(handle string, tables []string, acquire time.Duration) 14 | 15 | GraveyardLowWatermark(tableName string, lowWatermark Revision) 16 | GraveyardCleaningDuration(tableName string, duration time.Duration) 17 | GraveyardObjectCount(tableName string, numDeletedObjects int) 18 | ObjectCount(tableName string, numObjects int) 19 | 20 | 
DeleteTrackerCount(tableName string, numTrackers int) 21 | Revision(tableName string, revision Revision) 22 | } 23 | 24 | // ExpVarMetrics is a simple implementation for the metrics. 25 | type ExpVarMetrics struct { 26 | LockContentionVar *expvar.Map 27 | GraveyardCleaningDurationVar *expvar.Map 28 | GraveyardLowWatermarkVar *expvar.Map 29 | GraveyardObjectCountVar *expvar.Map 30 | ObjectCountVar *expvar.Map 31 | WriteTxnAcquisitionVar *expvar.Map 32 | WriteTxnDurationVar *expvar.Map 33 | DeleteTrackerCountVar *expvar.Map 34 | RevisionVar *expvar.Map 35 | } 36 | 37 | func (m *ExpVarMetrics) String() (out string) { 38 | var b strings.Builder 39 | m.LockContentionVar.Do(func(kv expvar.KeyValue) { 40 | fmt.Fprintf(&b, "lock_contention[%s]: %s\n", kv.Key, kv.Value.String()) 41 | }) 42 | m.GraveyardCleaningDurationVar.Do(func(kv expvar.KeyValue) { 43 | fmt.Fprintf(&b, "graveyard_cleaning_duration[%s]: %s\n", kv.Key, kv.Value.String()) 44 | }) 45 | m.GraveyardLowWatermarkVar.Do(func(kv expvar.KeyValue) { 46 | fmt.Fprintf(&b, "graveyard_low_watermark[%s]: %s\n", kv.Key, kv.Value.String()) 47 | }) 48 | m.GraveyardObjectCountVar.Do(func(kv expvar.KeyValue) { 49 | fmt.Fprintf(&b, "graveyard_object_count[%s]: %s\n", kv.Key, kv.Value.String()) 50 | }) 51 | m.ObjectCountVar.Do(func(kv expvar.KeyValue) { 52 | fmt.Fprintf(&b, "object_count[%s]: %s\n", kv.Key, kv.Value.String()) 53 | }) 54 | m.WriteTxnAcquisitionVar.Do(func(kv expvar.KeyValue) { 55 | fmt.Fprintf(&b, "write_txn_acquisition[%s]: %s\n", kv.Key, kv.Value.String()) 56 | }) 57 | m.WriteTxnDurationVar.Do(func(kv expvar.KeyValue) { 58 | fmt.Fprintf(&b, "write_txn_duration[%s]: %s\n", kv.Key, kv.Value.String()) 59 | }) 60 | m.DeleteTrackerCountVar.Do(func(kv expvar.KeyValue) { 61 | fmt.Fprintf(&b, "delete_tracker_count[%s]: %s\n", kv.Key, kv.Value.String()) 62 | }) 63 | m.RevisionVar.Do(func(kv expvar.KeyValue) { 64 | fmt.Fprintf(&b, "revision[%s]: %s\n", kv.Key, kv.Value.String()) 65 | }) 66 | 67 | return b.String() 68 | 
} 69 | 70 | func NewExpVarMetrics(publish bool) *ExpVarMetrics { 71 | newMap := func(name string) *expvar.Map { 72 | if publish { 73 | return expvar.NewMap(name) 74 | } 75 | return new(expvar.Map).Init() 76 | } 77 | return &ExpVarMetrics{ 78 | LockContentionVar: newMap("lock_contention"), 79 | GraveyardCleaningDurationVar: newMap("graveyard_cleaning_duration"), 80 | GraveyardLowWatermarkVar: newMap("graveyard_low_watermark"), 81 | GraveyardObjectCountVar: newMap("graveyard_object_count"), 82 | ObjectCountVar: newMap("object_count"), 83 | WriteTxnAcquisitionVar: newMap("write_txn_acquisition"), 84 | WriteTxnDurationVar: newMap("write_txn_duration"), 85 | DeleteTrackerCountVar: newMap("delete_tracker_count"), 86 | RevisionVar: newMap("revision"), 87 | } 88 | } 89 | 90 | func (m *ExpVarMetrics) DeleteTrackerCount(name string, numTrackers int) { 91 | var intVar expvar.Int 92 | intVar.Set(int64(numTrackers)) 93 | m.DeleteTrackerCountVar.Set(name, &intVar) 94 | } 95 | 96 | func (m *ExpVarMetrics) Revision(name string, revision uint64) { 97 | var intVar expvar.Int 98 | intVar.Set(int64(revision)) 99 | m.RevisionVar.Set(name, &intVar) 100 | } 101 | 102 | func (m *ExpVarMetrics) GraveyardCleaningDuration(name string, duration time.Duration) { 103 | m.GraveyardCleaningDurationVar.AddFloat(name, duration.Seconds()) 104 | } 105 | 106 | func (m *ExpVarMetrics) GraveyardLowWatermark(name string, lowWatermark Revision) { 107 | var intVar expvar.Int 108 | intVar.Set(int64(lowWatermark)) // unfortunately overflows at 2^63 109 | m.GraveyardLowWatermarkVar.Set(name, &intVar) 110 | } 111 | 112 | func (m *ExpVarMetrics) GraveyardObjectCount(name string, numDeletedObjects int) { 113 | var intVar expvar.Int 114 | intVar.Set(int64(numDeletedObjects)) 115 | m.GraveyardObjectCountVar.Set(name, &intVar) 116 | } 117 | 118 | func (m *ExpVarMetrics) ObjectCount(name string, numObjects int) { 119 | var intVar expvar.Int 120 | intVar.Set(int64(numObjects)) 121 | m.ObjectCountVar.Set(name, 
&intVar) 122 | } 123 | 124 | func (m *ExpVarMetrics) WriteTxnDuration(handle string, tables []string, acquire time.Duration) { 125 | m.WriteTxnDurationVar.AddFloat(handle+"/"+strings.Join(tables, "+"), acquire.Seconds()) 126 | } 127 | 128 | func (m *ExpVarMetrics) WriteTxnTotalAcquisition(handle string, tables []string, acquire time.Duration) { 129 | m.WriteTxnAcquisitionVar.AddFloat(handle+"/"+strings.Join(tables, "+"), acquire.Seconds()) 130 | } 131 | 132 | func (m *ExpVarMetrics) WriteTxnTableAcquisition(handle string, tableName string, acquire time.Duration) { 133 | m.LockContentionVar.AddFloat(handle+"/"+tableName, acquire.Seconds()) 134 | } 135 | 136 | var _ Metrics = &ExpVarMetrics{} 137 | 138 | type NopMetrics struct{} 139 | 140 | // DeleteTrackerCount implements Metrics. 141 | func (*NopMetrics) DeleteTrackerCount(tableName string, numTrackers int) { 142 | } 143 | 144 | // GraveyardCleaningDuration implements Metrics. 145 | func (*NopMetrics) GraveyardCleaningDuration(tableName string, duration time.Duration) { 146 | } 147 | 148 | // GraveyardLowWatermark implements Metrics. 149 | func (*NopMetrics) GraveyardLowWatermark(tableName string, lowWatermark uint64) { 150 | } 151 | 152 | // GraveyardObjectCount implements Metrics. 153 | func (*NopMetrics) GraveyardObjectCount(tableName string, numDeletedObjects int) { 154 | } 155 | 156 | // ObjectCount implements Metrics. 157 | func (*NopMetrics) ObjectCount(tableName string, numObjects int) { 158 | } 159 | 160 | // Revision implements Metrics. 161 | func (*NopMetrics) Revision(tableName string, revision uint64) { 162 | } 163 | 164 | // WriteTxnDuration implements Metrics. 165 | func (*NopMetrics) WriteTxnDuration(handle string, tables []string, acquire time.Duration) { 166 | } 167 | 168 | // WriteTxnTableAcquisition implements Metrics. 169 | func (*NopMetrics) WriteTxnTableAcquisition(handle string, tableName string, acquire time.Duration) { 170 | } 171 | 172 | // WriteTxnTotalAcquisition implements Metrics. 
173 | func (*NopMetrics) WriteTxnTotalAcquisition(handle string, tables []string, acquire time.Duration) { 174 | } 175 | 176 | var _ Metrics = &NopMetrics{} 177 | -------------------------------------------------------------------------------- /observable.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "context" 8 | 9 | "github.com/cilium/stream" 10 | ) 11 | 12 | // Observable creates an observable from the given table for observing the changes 13 | // to the table as a stream of events. 14 | // 15 | // For high-churn tables it's advisable to apply rate-limiting to the stream to 16 | // decrease overhead (stream.Throttle). 17 | func Observable[Obj any](db *DB, table Table[Obj]) stream.Observable[Change[Obj]] { 18 | return &observable[Obj]{db, table} 19 | } 20 | 21 | type observable[Obj any] struct { 22 | db *DB 23 | table Table[Obj] 24 | } 25 | 26 | func (to *observable[Obj]) Observe(ctx context.Context, next func(Change[Obj]), complete func(error)) { 27 | go func() { 28 | txn := to.db.WriteTxn(to.table) 29 | iter, err := to.table.Changes(txn) 30 | txn.Commit() 31 | if err != nil { 32 | complete(err) 33 | return 34 | } 35 | defer complete(nil) 36 | 37 | for { 38 | changes, watch := iter.Next(to.db.ReadTxn()) 39 | for change := range changes { 40 | if ctx.Err() != nil { 41 | break 42 | } 43 | next(change) 44 | } 45 | select { 46 | case <-ctx.Done(): 47 | return 48 | case <-watch: 49 | } 50 | } 51 | }() 52 | } 53 | -------------------------------------------------------------------------------- /part/cache.go: -------------------------------------------------------------------------------- 1 | package part 2 | 3 | import "unsafe" 4 | 5 | const nodeMutatedSize = 32 // must be power-of-two 6 | 7 | type nodeMutated[T any] struct { 8 | ptrs [nodeMutatedSize]*header[T] 9 | used bool 10 | } 11 | 12 | func (p 
*nodeMutated[T]) put(ptr *header[T]) { 13 | ptrInt := uintptr(unsafe.Pointer(ptr)) 14 | p.ptrs[slot(ptrInt)] = ptr 15 | p.used = true 16 | } 17 | 18 | func (p *nodeMutated[T]) exists(ptr *header[T]) bool { 19 | ptrInt := uintptr(unsafe.Pointer(ptr)) 20 | return p.ptrs[slot(ptrInt)] == ptr 21 | } 22 | 23 | func slot(p uintptr) int { 24 | var slot uint8 25 | // use some relevant bits from the pointer 26 | slot = slot + uint8(p>>4) 27 | slot = slot + uint8(p>>12) 28 | slot = slot + uint8(p>>20) 29 | return int(slot & (nodeMutatedSize - 1)) 30 | } 31 | 32 | func (p *nodeMutated[T]) clear() { 33 | if p.used { 34 | clear(p.ptrs[:]) 35 | } 36 | p.used = false 37 | } 38 | -------------------------------------------------------------------------------- /part/iterator.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package part 5 | 6 | import ( 7 | "bytes" 8 | "slices" 9 | "sort" 10 | ) 11 | 12 | // Iterator for key and value pairs where value is of type T 13 | type Iterator[T any] struct { 14 | next [][]*header[T] // sets of edges to explore 15 | } 16 | 17 | // Clone returns a copy of the iterator, allowing restarting 18 | // the iterator from scratch. 19 | func (it *Iterator[T]) Clone() *Iterator[T] { 20 | // Since the iterator does not mutate the edge array elements themselves ([]*header[T]) 21 | // it is enough to do a shallow clone here. 22 | return &Iterator[T]{slices.Clone(it.next)} 23 | } 24 | 25 | // Next returns the next key, value and true if the value exists, 26 | // otherwise it returns false. 27 | func (it *Iterator[T]) Next() (key []byte, value T, ok bool) { 28 | for len(it.next) > 0 { 29 | // Pop the next set of edges to explore 30 | edges := it.next[len(it.next)-1] 31 | for len(edges) > 0 && edges[0] == nil { 32 | // Node256 may have nil children, so jump over them. 
33 | edges = edges[1:] 34 | } 35 | it.next = it.next[:len(it.next)-1] 36 | 37 | if len(edges) == 0 { 38 | continue 39 | } else if len(edges) > 1 { 40 | // More edges remain to be explored, add them back. 41 | it.next = append(it.next, edges[1:]) 42 | } 43 | 44 | // Follow the smallest edge and add its children to the queue. 45 | node := edges[0] 46 | 47 | if node.size() > 0 { 48 | it.next = append(it.next, node.children()) 49 | } 50 | if leaf := node.getLeaf(); leaf != nil { 51 | key = leaf.key 52 | value = leaf.value 53 | ok = true 54 | return 55 | } 56 | } 57 | return 58 | } 59 | 60 | func newIterator[T any](start *header[T]) *Iterator[T] { 61 | if start == nil { 62 | return &Iterator[T]{nil} 63 | } 64 | return &Iterator[T]{[][]*header[T]{{start}}} 65 | } 66 | 67 | func prefixSearch[T any](root *header[T], key []byte) (*Iterator[T], <-chan struct{}) { 68 | this := root 69 | var watch <-chan struct{} 70 | for { 71 | if !this.isLeaf() && this.watch != nil { 72 | // Leaf watch channels only close when the leaf is manipulated, 73 | // thus we only return non-leaf watch channels. 
74 | watch = this.watch 75 | } 76 | 77 | switch { 78 | case bytes.Equal(key, this.prefix[:min(len(key), len(this.prefix))]): 79 | return newIterator(this), watch 80 | 81 | case bytes.HasPrefix(key, this.prefix): 82 | key = key[len(this.prefix):] 83 | if len(key) == 0 { 84 | return newIterator(this), this.watch 85 | } 86 | 87 | default: 88 | return newIterator[T](nil), root.watch 89 | } 90 | 91 | this = this.find(key[0]) 92 | if this == nil { 93 | return newIterator[T](nil), root.watch 94 | } 95 | } 96 | } 97 | 98 | func traverseToMin[T any](n *header[T], edges [][]*header[T]) [][]*header[T] { 99 | if leaf := n.getLeaf(); leaf != nil { 100 | return append(edges, []*header[T]{n}) 101 | } 102 | children := n.children() 103 | 104 | // Find the first non-nil child 105 | for len(children) > 0 && children[0] == nil { 106 | children = children[1:] 107 | } 108 | 109 | if len(children) > 0 { 110 | // Add the larger children. 111 | if len(children) > 1 { 112 | edges = append(edges, children[1:]) 113 | } 114 | // Recurse into the smallest child 115 | return traverseToMin(children[0], edges) 116 | } 117 | return edges 118 | } 119 | 120 | func lowerbound[T any](start *header[T], key []byte) *Iterator[T] { 121 | // The starting edges to explore. This contains all larger nodes encountered 122 | // on the path to the node larger or equal to the key. 123 | edges := [][]*header[T]{} 124 | this := start 125 | loop: 126 | for { 127 | switch bytes.Compare(this.prefix, key[:min(len(key), len(this.prefix))]) { 128 | case -1: 129 | // Prefix is smaller, stop here and return an iterator for 130 | // the larger nodes in the parent's. 131 | break loop 132 | 133 | case 0: 134 | if len(this.prefix) == len(key) { 135 | // Exact match. 136 | edges = append(edges, []*header[T]{this}) 137 | break loop 138 | } 139 | 140 | // Prefix matches the beginning of the key, but more 141 | // remains of the key. Drop the matching part and keep 142 | // going further. 
143 | key = key[len(this.prefix):] 144 | 145 | if this.kind() == nodeKind256 { 146 | children := this.node256().children[:] 147 | idx := int(key[0]) 148 | this = children[idx] 149 | 150 | // Add all larger children and recurse further. 151 | children = children[idx+1:] 152 | for len(children) > 0 && children[0] == nil { 153 | children = children[1:] 154 | } 155 | edges = append(edges, children) 156 | 157 | if this == nil { 158 | break loop 159 | } 160 | } else { 161 | children := this.children() 162 | 163 | // Find the smallest child that is equal or larger than the lower bound 164 | idx := sort.Search(len(children), func(i int) bool { 165 | return children[i].prefix[0] >= key[0] 166 | }) 167 | if idx >= this.size() { 168 | break loop 169 | } 170 | // Add all larger children and recurse further. 171 | if len(children) > idx+1 { 172 | edges = append(edges, children[idx+1:]) 173 | } 174 | this = children[idx] 175 | } 176 | 177 | case 1: 178 | // Prefix bigger than lowerbound, go to smallest node and stop. 179 | edges = traverseToMin(this, edges) 180 | break loop 181 | } 182 | } 183 | 184 | if len(edges) > 0 { 185 | return &Iterator[T]{edges} 186 | } 187 | return &Iterator[T]{nil} 188 | } 189 | -------------------------------------------------------------------------------- /part/map.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package part 5 | 6 | import ( 7 | "bytes" 8 | "encoding/json" 9 | "fmt" 10 | "iter" 11 | "reflect" 12 | 13 | "gopkg.in/yaml.v3" 14 | ) 15 | 16 | // Map of key-value pairs. The zero value is ready for use, provided 17 | // that the key type has been registered with RegisterKeyType. 18 | // 19 | // Map is a typed wrapper around Tree[T] for working with 20 | // keys that are not []byte. 
21 | type Map[K, V any] struct { 22 | bytesFromKey func(K) []byte 23 | tree *Tree[mapKVPair[K, V]] 24 | } 25 | 26 | type mapKVPair[K, V any] struct { 27 | Key K `json:"k" yaml:"k"` 28 | Value V `json:"v" yaml:"v"` 29 | } 30 | 31 | // FromMap copies values from the hash map into the given Map. 32 | // This is not implemented as a method on Map[K,V] as hash maps require the 33 | // comparable constraint and we do not need to limit Map[K, V] to that. 34 | func FromMap[K comparable, V any](m Map[K, V], hm map[K]V) Map[K, V] { 35 | m.ensureTree() 36 | txn := m.tree.Txn() 37 | for k, v := range hm { 38 | txn.Insert(m.bytesFromKey(k), mapKVPair[K, V]{k, v}) 39 | } 40 | m.tree = txn.CommitOnly() 41 | return m 42 | } 43 | 44 | // ensureTree checks that the tree is not nil and allocates it if 45 | // it is. The whole nil tree thing is to make sure that creating 46 | // an empty map does not allocate anything. 47 | func (m *Map[K, V]) ensureTree() { 48 | if m.tree == nil { 49 | m.tree = New[mapKVPair[K, V]](RootOnlyWatch) 50 | } 51 | m.bytesFromKey = lookupKeyType[K]() 52 | } 53 | 54 | // Get a value from the map by its key. 55 | func (m Map[K, V]) Get(key K) (value V, found bool) { 56 | if m.tree == nil { 57 | return 58 | } 59 | kv, _, found := m.tree.Get(m.bytesFromKey(key)) 60 | return kv.Value, found 61 | } 62 | 63 | // Set a value. Returns a new map with the value set. 64 | // Original map is unchanged. 65 | func (m Map[K, V]) Set(key K, value V) Map[K, V] { 66 | m.ensureTree() 67 | txn := m.tree.Txn() 68 | txn.Insert(m.bytesFromKey(key), mapKVPair[K, V]{key, value}) 69 | m.tree = txn.CommitOnly() 70 | return m 71 | } 72 | 73 | // Delete a value from the map. Returns a new map 74 | // without the element pointed to by the key (if found). 
75 | func (m Map[K, V]) Delete(key K) Map[K, V] { 76 | if m.tree != nil { 77 | txn := m.tree.Txn() 78 | txn.Delete(m.bytesFromKey(key)) 79 | // Map is a struct passed by value, so we can modify 80 | // it without changing the caller's view of it. 81 | m.tree = txn.CommitOnly() 82 | } 83 | return m 84 | } 85 | 86 | func toSeq2[K, V any](iter *Iterator[mapKVPair[K, V]]) iter.Seq2[K, V] { 87 | return func(yield func(K, V) bool) { 88 | if iter == nil { 89 | return 90 | } 91 | iter = iter.Clone() 92 | for _, kv, ok := iter.Next(); ok; _, kv, ok = iter.Next() { 93 | if !yield(kv.Key, kv.Value) { 94 | break 95 | } 96 | } 97 | } 98 | } 99 | 100 | // LowerBound iterates over all keys in order with value equal 101 | // to or greater than [from]. 102 | func (m Map[K, V]) LowerBound(from K) iter.Seq2[K, V] { 103 | if m.tree == nil { 104 | return toSeq2[K, V](nil) 105 | } 106 | return toSeq2(m.tree.LowerBound(m.bytesFromKey(from))) 107 | } 108 | 109 | // Prefix iterates in order over all keys that start with 110 | // the given prefix. 111 | func (m Map[K, V]) Prefix(prefix K) iter.Seq2[K, V] { 112 | if m.tree == nil { 113 | return toSeq2[K, V](nil) 114 | } 115 | iter, _ := m.tree.Prefix(m.bytesFromKey(prefix)) 116 | return toSeq2(iter) 117 | } 118 | 119 | // All iterates every key-value in the map in order. 120 | // The order is in bytewise order of the byte slice 121 | // returned by bytesFromKey. 122 | func (m Map[K, V]) All() iter.Seq2[K, V] { 123 | if m.tree == nil { 124 | return toSeq2[K, V](nil) 125 | } 126 | return toSeq2(m.tree.Iterator()) 127 | } 128 | 129 | // EqualKeys returns true if both maps contain the same keys. 
130 | func (m Map[K, V]) EqualKeys(other Map[K, V]) bool { 131 | switch { 132 | case m.tree == nil && other.tree == nil: 133 | return true 134 | case m.Len() != other.Len(): 135 | return false 136 | default: 137 | iter1 := m.tree.Iterator() 138 | iter2 := other.tree.Iterator() 139 | for { 140 | k1, _, ok := iter1.Next() 141 | if !ok { 142 | break 143 | } 144 | k2, _, _ := iter2.Next() 145 | // Equal lengths, no need to check 'ok' for 'iter2'. 146 | if !bytes.Equal(k1, k2) { 147 | return false 148 | } 149 | } 150 | return true 151 | } 152 | } 153 | 154 | // SlowEqual returns true if the two maps contain the same keys and values. 155 | // Value comparison is implemented with reflect.DeepEqual which makes this 156 | // slow and mostly useful for testing. 157 | func (m Map[K, V]) SlowEqual(other Map[K, V]) bool { 158 | switch { 159 | case m.tree == nil && other.tree == nil: 160 | return true 161 | case m.Len() != other.Len(): 162 | return false 163 | default: 164 | iter1 := m.tree.Iterator() 165 | iter2 := other.tree.Iterator() 166 | for { 167 | k1, v1, ok := iter1.Next() 168 | if !ok { 169 | break 170 | } 171 | k2, v2, _ := iter2.Next() 172 | // Equal lengths, no need to check 'ok' for 'iter2'. 173 | if !bytes.Equal(k1, k2) || !reflect.DeepEqual(v1, v2) { 174 | return false 175 | } 176 | } 177 | return true 178 | } 179 | } 180 | 181 | // Len returns the number of elements in the map. 
182 | func (m Map[K, V]) Len() int { 183 | if m.tree == nil { 184 | return 0 185 | } 186 | return m.tree.size 187 | } 188 | 189 | func (m Map[K, V]) MarshalJSON() ([]byte, error) { 190 | if m.tree == nil { 191 | return []byte("[]"), nil 192 | } 193 | 194 | var b bytes.Buffer 195 | b.WriteRune('[') 196 | iter := m.tree.Iterator() 197 | _, kv, ok := iter.Next() 198 | for ok { 199 | bs, err := json.Marshal(kv) 200 | if err != nil { 201 | return nil, err 202 | } 203 | b.Write(bs) 204 | _, kv, ok = iter.Next() 205 | if ok { 206 | b.WriteRune(',') 207 | } 208 | } 209 | b.WriteRune(']') 210 | return b.Bytes(), nil 211 | } 212 | 213 | func (m *Map[K, V]) UnmarshalJSON(data []byte) error { 214 | dec := json.NewDecoder(bytes.NewReader(data)) 215 | t, err := dec.Token() 216 | if err != nil { 217 | return err 218 | } 219 | if d, ok := t.(json.Delim); !ok || d != '[' { 220 | return fmt.Errorf("%T.UnmarshalJSON: expected '[' got %v", m, t) 221 | } 222 | m.ensureTree() 223 | txn := m.tree.Txn() 224 | for dec.More() { 225 | var kv mapKVPair[K, V] 226 | err := dec.Decode(&kv) 227 | if err != nil { 228 | return err 229 | } 230 | txn.Insert(m.bytesFromKey(kv.Key), mapKVPair[K, V]{kv.Key, kv.Value}) 231 | } 232 | 233 | t, err = dec.Token() 234 | if err != nil { 235 | return err 236 | } 237 | if d, ok := t.(json.Delim); !ok || d != ']' { 238 | return fmt.Errorf("%T.UnmarshalJSON: expected ']' got %v", m, t) 239 | } 240 | m.tree = txn.CommitOnly() 241 | return nil 242 | } 243 | 244 | func (m Map[K, V]) MarshalYAML() (any, error) { 245 | kvs := make([]mapKVPair[K, V], 0, m.Len()) 246 | if m.tree != nil { 247 | iter := m.tree.Iterator() 248 | for _, kv, ok := iter.Next(); ok; _, kv, ok = iter.Next() { 249 | kvs = append(kvs, kv) 250 | } 251 | } 252 | return kvs, nil 253 | } 254 | 255 | func (m *Map[K, V]) UnmarshalYAML(value *yaml.Node) error { 256 | if value.Kind != yaml.SequenceNode { 257 | return fmt.Errorf("%T.UnmarshalYAML: expected sequence", m) 258 | } 259 | m.ensureTree() 260 | 
txn := m.tree.Txn() 261 | for _, e := range value.Content { 262 | var kv mapKVPair[K, V] 263 | if err := e.Decode(&kv); err != nil { 264 | return err 265 | } 266 | txn.Insert(m.bytesFromKey(kv.Key), mapKVPair[K, V]{kv.Key, kv.Value}) 267 | } 268 | m.tree = txn.CommitOnly() 269 | return nil 270 | } 271 | -------------------------------------------------------------------------------- /part/map_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package part_test 5 | 6 | import ( 7 | "encoding/json" 8 | "fmt" 9 | "iter" 10 | "math/rand/v2" 11 | "testing" 12 | 13 | "github.com/cilium/statedb/part" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | "gopkg.in/yaml.v3" 17 | ) 18 | 19 | func TestStringMap(t *testing.T) { 20 | var m part.Map[string, int] 21 | 22 | // 23 | // Operations on empty map 24 | // 25 | 26 | v, ok := m.Get("nonexisting") 27 | assert.False(t, ok, "Get non-existing") 28 | assert.Equal(t, 0, v) 29 | 30 | assertIterEmpty := func(it iter.Seq2[string, int]) { 31 | t.Helper() 32 | for range it { 33 | t.Fatalf("expected empty iterator") 34 | } 35 | } 36 | assertIterEmpty(m.LowerBound("")) 37 | assertIterEmpty(m.Prefix("")) 38 | assertIterEmpty(m.All()) 39 | 40 | // 41 | // Operations on non-empty map 42 | // 43 | 44 | // Ordered list of key-value pairs we're testing with. 45 | // Prefix set so that Map keeps them in the same order. 46 | kvs := []struct { 47 | k string 48 | v int 49 | }{ 50 | {"1_one", 1}, 51 | {"2_two", 2}, 52 | {"3_three", 3}, 53 | } 54 | 55 | // Set some values in two different ways. 
56 | m = m.Set("1_one", 1) 57 | m = part.FromMap(m, map[string]int{ 58 | "2_two": 2, 59 | "3_three": 3, 60 | }) 61 | 62 | // Setting on a copy doeen't affect original 63 | m.Set("4_four", 4) 64 | _, ok = m.Get("4_four") 65 | assert.False(t, ok, "Get non-existing") 66 | 67 | // Getting a non-existing value still does the same. 68 | v, ok = m.Get("nonexisting") 69 | assert.False(t, ok, "Get non-existing") 70 | assert.Equal(t, 0, v) 71 | 72 | for _, kv := range kvs { 73 | v, ok := m.Get(kv.k) 74 | assert.True(t, ok, "Get %q", kv.k) 75 | assert.Equal(t, v, kv.v) 76 | } 77 | 78 | expected := kvs 79 | for k, v := range m.All() { 80 | kv := expected[0] 81 | expected = expected[1:] 82 | assert.EqualValues(t, kv.k, k) 83 | assert.EqualValues(t, kv.v, v) 84 | } 85 | assert.Empty(t, expected) 86 | 87 | expected = kvs[1:] 88 | for k, v := range m.LowerBound("2") { 89 | kv := expected[0] 90 | expected = expected[1:] 91 | assert.EqualValues(t, kv.k, k) 92 | assert.EqualValues(t, kv.v, v) 93 | } 94 | assert.Empty(t, expected) 95 | 96 | expected = kvs[1:2] 97 | for k, v := range m.Prefix("2") { 98 | kv := expected[0] 99 | expected = expected[1:] 100 | assert.EqualValues(t, kv.k, k) 101 | assert.EqualValues(t, kv.v, v) 102 | } 103 | assert.Empty(t, expected) 104 | 105 | assert.Equal(t, 3, m.Len()) 106 | 107 | mOld := m 108 | m = m.Delete(kvs[0].k) 109 | _, ok = m.Get(kvs[0].k) 110 | assert.False(t, ok, "Get after Delete") 111 | 112 | _, ok = mOld.Get(kvs[0].k) 113 | assert.True(t, ok, "Original modified by Delete") 114 | mOld = mOld.Delete(kvs[0].k) 115 | _, ok = mOld.Get(kvs[0].k) 116 | assert.False(t, ok, "Get after Delete") 117 | 118 | assert.Equal(t, 2, m.Len()) 119 | } 120 | 121 | func TestUint64Map(t *testing.T) { 122 | // TestStringMap tests most of the operations. We just check here that 123 | // fromBytes and toBytes work and can iterate in the right order. 
124 | var m part.Map[uint64, int] 125 | m = m.Set(42, 42) 126 | m = m.Set(55, 55) 127 | m = m.Set(72, 72) 128 | 129 | v, ok := m.Get(42) 130 | assert.True(t, ok, "Get 42") 131 | assert.Equal(t, 42, v) 132 | 133 | count := 0 134 | expected := []uint64{55, 72} 135 | for k, v := range m.LowerBound(55) { 136 | kv := expected[0] 137 | expected = expected[1:] 138 | assert.EqualValues(t, kv, k) 139 | assert.EqualValues(t, kv, v) 140 | count++ 141 | } 142 | assert.Equal(t, 2, count) 143 | } 144 | 145 | func TestRegisterKeyType(t *testing.T) { 146 | type testKey struct { 147 | X string 148 | } 149 | part.RegisterKeyType(func(k testKey) []byte { return []byte(k.X) }) 150 | 151 | var m part.Map[testKey, int] 152 | m = m.Set(testKey{"hello"}, 123) 153 | 154 | v, ok := m.Get(testKey{"hello"}) 155 | assert.True(t, ok, "Get 'hello'") 156 | assert.Equal(t, 123, v) 157 | 158 | for k, v := range m.All() { 159 | assert.Equal(t, testKey{"hello"}, k) 160 | assert.Equal(t, 123, v) 161 | } 162 | } 163 | 164 | func TestMapJSON(t *testing.T) { 165 | var m part.Map[string, int] 166 | m = m.Set("foo", 1).Set("bar", 2).Set("baz", 3) 167 | 168 | bs, err := json.Marshal(m) 169 | require.NoError(t, err, "Marshal") 170 | 171 | var m2 part.Map[string, int] 172 | err = json.Unmarshal(bs, &m2) 173 | require.NoError(t, err, "Unmarshal") 174 | require.True(t, m.SlowEqual(m2), "SlowEqual") 175 | } 176 | 177 | func TestMapYAMLStringKey(t *testing.T) { 178 | var m part.Map[string, int] 179 | 180 | bs, err := yaml.Marshal(m) 181 | require.NoError(t, err, "Marshal") 182 | require.Equal(t, "[]\n", string(bs)) 183 | 184 | m = m.Set("foo", 1).Set("bar", 2).Set("baz", 3) 185 | 186 | bs, err = yaml.Marshal(m) 187 | require.NoError(t, err, "Marshal") 188 | require.Equal(t, "- k: bar\n v: 2\n- k: baz\n v: 3\n- k: foo\n v: 1\n", string(bs)) 189 | 190 | var m2 part.Map[string, int] 191 | err = yaml.Unmarshal(bs, &m2) 192 | require.NoError(t, err, "Unmarshal") 193 | require.True(t, m.SlowEqual(m2), "SlowEqual") 194 
| } 195 | 196 | func TestMapYAMLStructKey(t *testing.T) { 197 | type key struct { 198 | A int `yaml:"a"` 199 | B string `yaml:"b"` 200 | } 201 | part.RegisterKeyType[key](func(k key) []byte { 202 | return []byte(fmt.Sprintf("%d-%s", k.A, k.B)) 203 | }) 204 | var m part.Map[key, int] 205 | m = m.Set(key{1, "one"}, 1).Set(key{2, "two"}, 2).Set(key{3, "three"}, 3) 206 | 207 | bs, err := yaml.Marshal(m) 208 | require.NoError(t, err, "Marshal") 209 | 210 | var m2 part.Map[key, int] 211 | err = yaml.Unmarshal(bs, &m2) 212 | require.NoError(t, err, "Unmarshal") 213 | require.True(t, m.SlowEqual(m2), "SlowEqual") 214 | } 215 | 216 | func Benchmark_Uint64Map_Random(b *testing.B) { 217 | numItems := 1000 218 | keys := map[uint64]int{} 219 | for len(keys) < numItems { 220 | k := uint64(rand.Int64()) 221 | keys[k] = int(k) 222 | } 223 | for n := 0; n < b.N; n++ { 224 | var m part.Map[uint64, int] 225 | for k, v := range keys { 226 | m = m.Set(k, v) 227 | v2, ok := m.Get(k) 228 | if !ok || v != v2 { 229 | b.Fatalf("Get did not return value") 230 | } 231 | } 232 | } 233 | b.ReportMetric(float64(numItems*b.N)/b.Elapsed().Seconds(), "items/sec") 234 | } 235 | 236 | func Benchmark_Uint64Map_Sequential(b *testing.B) { 237 | numItems := 1000 238 | 239 | for n := 0; n < b.N; n++ { 240 | var m part.Map[uint64, int] 241 | for i := 0; i < numItems; i++ { 242 | k := uint64(i) 243 | m = m.Set(k, i) 244 | v, ok := m.Get(k) 245 | if !ok || v != i { 246 | b.Fatalf("Get did not return value") 247 | } 248 | } 249 | } 250 | b.ReportMetric(float64(numItems*b.N)/b.Elapsed().Seconds(), "items/sec") 251 | } 252 | -------------------------------------------------------------------------------- /part/ops.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package part 5 | 6 | // Ops is the common operations that can be performed with a Tree 7 | // or Txn. 
8 | type Ops[T any] interface { 9 | // Len returns the number of objects in the tree. 10 | Len() int 11 | 12 | // Get fetches the value associated with the given key. 13 | // Returns the value, a watch channel (which is closed on 14 | // modification to the key) and boolean which is true if 15 | // value was found. 16 | Get(key []byte) (T, <-chan struct{}, bool) 17 | 18 | // Prefix returns an iterator for all objects that starts with the 19 | // given prefix, and a channel that closes when any objects matching 20 | // the given prefix are upserted or deleted. 21 | Prefix(key []byte) (*Iterator[T], <-chan struct{}) 22 | 23 | // LowerBound returns an iterator for all objects that have a 24 | // key equal or higher than the given 'key'. 25 | LowerBound(key []byte) *Iterator[T] 26 | 27 | // RootWatch returns a watch channel for the root of the tree. 28 | // Since this is the channel associated with the root, this closes 29 | // when there are any changes to the tree. 30 | RootWatch() <-chan struct{} 31 | 32 | // Iterator returns an iterator for all objects. 33 | Iterator() *Iterator[T] 34 | 35 | // PrintTree to the standard output. For debugging. 36 | PrintTree() 37 | } 38 | 39 | var ( 40 | _ Ops[int] = &Tree[int]{} 41 | _ Ops[int] = &Txn[int]{} 42 | ) 43 | -------------------------------------------------------------------------------- /part/quick_test.go: -------------------------------------------------------------------------------- 1 | package part_test 2 | 3 | import ( 4 | "slices" 5 | "testing" 6 | "testing/quick" 7 | 8 | "github.com/cilium/statedb/part" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | var quickConfig = &quick.Config{ 13 | // Use a higher count in order to hit the Node256 cases. 
14 | MaxCount: 2000, 15 | } 16 | 17 | func TestQuick_InsertGetPrefix(t *testing.T) { 18 | tree := part.New[string]() 19 | insert := func(key, value string) any { 20 | _, _, tree = tree.Insert([]byte(key), value) 21 | return value 22 | } 23 | 24 | get := func(key, value string) any { 25 | val, _, _ := tree.Get([]byte(key)) 26 | if val != value { 27 | return val 28 | } 29 | 30 | iter, _ := tree.Prefix([]byte(key)) 31 | _, v, _ := iter.Next() 32 | return v 33 | } 34 | 35 | require.NoError(t, 36 | quick.CheckEqual(insert, get, quickConfig), 37 | ) 38 | } 39 | 40 | func TestQuick_IteratorReuse(t *testing.T) { 41 | tree := part.New[string]() 42 | 43 | iterate := func(key, value string, cloneFirst bool) bool { 44 | _, _, tree = tree.Insert([]byte(key), value) 45 | v, _, ok := tree.Get([]byte(key)) 46 | if !ok || value != v { 47 | return false 48 | } 49 | 50 | prefixIter, _ := tree.Prefix([]byte(key)) 51 | iterators := []*part.Iterator[string]{ 52 | tree.LowerBound([]byte(key)), 53 | prefixIter, 54 | } 55 | 56 | for _, iter := range iterators { 57 | iter2 := iter.Clone() 58 | 59 | collect := func(it *part.Iterator[string]) (out []string) { 60 | for k, v, ok := it.Next(); ok; k, v, ok = it.Next() { 61 | out = append(out, string(k)+"="+v) 62 | } 63 | return 64 | } 65 | 66 | var fst, snd []string 67 | if cloneFirst { 68 | snd = collect(iter2) 69 | fst = collect(iter) 70 | } else { 71 | fst = collect(iter) 72 | snd = collect(iter2) 73 | } 74 | 75 | if !slices.Equal(fst, snd) { 76 | return false 77 | } 78 | } 79 | return true 80 | } 81 | 82 | require.NoError(t, 83 | quick.Check(iterate, quickConfig), 84 | ) 85 | } 86 | 87 | func TestQuick_Delete(t *testing.T) { 88 | tree := part.New[string]() 89 | 90 | do := func(key, value string, delete bool) bool { 91 | _, _, tree = tree.Insert([]byte(key), value) 92 | treeAfterInsert := tree 93 | v, watch, ok := tree.Get([]byte(key)) 94 | if !ok || v != value { 95 | t.Logf("value not in tree after insert") 96 | return false 97 | } 98 | 99 | 
// delete some of the time to construct different variations of trees. 100 | if delete { 101 | _, _, tree = tree.Delete([]byte(key)) 102 | _, _, ok := tree.Get([]byte(key)) 103 | if ok { 104 | t.Logf("value exists after delete") 105 | return false 106 | } 107 | 108 | _, _, ok = treeAfterInsert.Get([]byte(key)) 109 | if !ok { 110 | t.Logf("value deleted from original") 111 | } 112 | 113 | // Check that watch channel closed. 114 | select { 115 | case <-watch: 116 | default: 117 | t.Logf("watch channel not closed") 118 | return false 119 | } 120 | } 121 | return true 122 | } 123 | 124 | require.NoError(t, quick.Check(do, quickConfig)) 125 | } 126 | 127 | func TestQuick_ClosedWatch(t *testing.T) { 128 | tree := part.New[string]() 129 | insert := func(key, value string) bool { 130 | _, _, tree = tree.Insert([]byte(key), value) 131 | treeAfterInsert := tree 132 | 133 | val, watch, ok := tree.Get([]byte(key)) 134 | if !ok { 135 | return false 136 | } 137 | if val != value { 138 | return false 139 | } 140 | 141 | select { 142 | case <-watch: 143 | return false 144 | default: 145 | } 146 | 147 | // Changing the key makes the channel close. 148 | _, _, tree = tree.Insert([]byte(key), "x") 149 | select { 150 | case <-watch: 151 | default: 152 | t.Logf("channel not closed!") 153 | return false 154 | } 155 | 156 | // Original tree unaffected. 
157 | val, _, ok = treeAfterInsert.Get([]byte(key)) 158 | if !ok || val != value { 159 | t.Logf("original changed!") 160 | return false 161 | } 162 | 163 | val, _, ok = tree.Get([]byte(key)) 164 | if !ok || val != "x" { 165 | t.Logf("new tree does not have x!") 166 | return false 167 | } 168 | 169 | return true 170 | } 171 | 172 | require.NoError(t, quick.Check(insert, quickConfig)) 173 | } 174 | -------------------------------------------------------------------------------- /part/registry.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package part 5 | 6 | import ( 7 | "encoding/binary" 8 | "fmt" 9 | "math" 10 | "reflect" 11 | "sync" 12 | "unicode/utf8" 13 | ) 14 | 15 | // keyTypeRegistry is a registry of functions to convert to/from keys (of type K). 16 | // This mechanism enables use of zero value and JSON marshalling and unmarshalling 17 | // with Map and Set. 18 | var keyTypeRegistry sync.Map // map[reflect.Type]func(K) []byte 19 | 20 | // RegisterKeyType registers a new key type to be used with the Map and Set types. 21 | // Intended to be called from init() functions. 22 | // For Set-only usage only the [bytesFromKey] function is needed. 23 | func RegisterKeyType[K any](bytesFromKey func(K) []byte) { 24 | keyType := reflect.TypeFor[K]() 25 | keyTypeRegistry.Store( 26 | keyType, 27 | bytesFromKey, 28 | ) 29 | } 30 | 31 | func lookupKeyType[K any]() func(K) []byte { 32 | keyType := reflect.TypeFor[K]() 33 | funcAny, ok := keyTypeRegistry.Load(keyType) 34 | if !ok { 35 | panic(fmt.Sprintf("Key type %q not registered with part.RegisterMapKeyType()", keyType)) 36 | } 37 | return funcAny.(func(K) []byte) 38 | } 39 | 40 | func init() { 41 | // Register common key types. 
42 | RegisterKeyType[string](func(s string) []byte { return []byte(s) }) 43 | RegisterKeyType[[]byte](func(b []byte) []byte { return b }) 44 | RegisterKeyType[byte](func(b byte) []byte { return []byte{b} }) 45 | RegisterKeyType[rune](func(r rune) []byte { return utf8.AppendRune(nil, r) }) 46 | RegisterKeyType[complex128](func(c complex128) []byte { 47 | buf := make([]byte, 0, 16) 48 | buf = binary.BigEndian.AppendUint64(buf, math.Float64bits(real(c))) 49 | buf = binary.BigEndian.AppendUint64(buf, math.Float64bits(imag(c))) 50 | return buf 51 | }) 52 | RegisterKeyType[float64](func(x float64) []byte { return binary.BigEndian.AppendUint64(nil, math.Float64bits(x)) }) 53 | RegisterKeyType[float32](func(x float32) []byte { return binary.BigEndian.AppendUint32(nil, math.Float32bits(x)) }) 54 | RegisterKeyType[uint64](func(x uint64) []byte { return binary.BigEndian.AppendUint64(nil, x) }) 55 | RegisterKeyType[uint32](func(x uint32) []byte { return binary.BigEndian.AppendUint32(nil, x) }) 56 | RegisterKeyType[uint16](func(x uint16) []byte { return binary.BigEndian.AppendUint16(nil, x) }) 57 | RegisterKeyType[int64](func(x int64) []byte { return binary.BigEndian.AppendUint64(nil, uint64(x)) }) 58 | RegisterKeyType[int32](func(x int32) []byte { return binary.BigEndian.AppendUint32(nil, uint32(x)) }) 59 | RegisterKeyType[int16](func(x int16) []byte { return binary.BigEndian.AppendUint16(nil, uint16(x)) }) 60 | RegisterKeyType[int](func(x int) []byte { return binary.BigEndian.AppendUint64(nil, uint64(x)) }) 61 | 62 | var ( 63 | trueBytes = []byte{'T'} 64 | falseBytes = []byte{'F'} 65 | ) 66 | RegisterKeyType[bool](func(b bool) []byte { 67 | if b { 68 | return trueBytes 69 | } else { 70 | return falseBytes 71 | } 72 | }) 73 | 74 | } 75 | -------------------------------------------------------------------------------- /part/set.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors 
of Cilium 3 | 4 | package part 5 | 6 | import ( 7 | "bytes" 8 | "encoding/json" 9 | "fmt" 10 | "iter" 11 | "slices" 12 | 13 | "gopkg.in/yaml.v3" 14 | ) 15 | 16 | // Set is a persistent (immutable) set of values. A Set can be 17 | // defined for any type for which a byte slice key can be derived. 18 | // 19 | // A zero value Set[T] can be used provided that the conversion 20 | // function for T have been registered with RegisterKeyType. 21 | // For Set-only use only [bytesFromKey] needs to be defined. 22 | type Set[T any] struct { 23 | toBytes func(T) []byte 24 | tree *Tree[T] 25 | } 26 | 27 | // NewSet creates a new set of T. 28 | // The value type T must be registered with RegisterKeyType. 29 | func NewSet[T any](values ...T) Set[T] { 30 | s := Set[T]{tree: New[T](RootOnlyWatch)} 31 | s.toBytes = lookupKeyType[T]() 32 | if len(values) > 0 { 33 | txn := s.tree.Txn() 34 | for _, v := range values { 35 | txn.Insert(s.toBytes(v), v) 36 | } 37 | s.tree = txn.CommitOnly() 38 | } 39 | return s 40 | } 41 | 42 | // Set a value. Returns a new set. Original is unchanged. 43 | func (s Set[T]) Set(v T) Set[T] { 44 | if s.tree == nil { 45 | return NewSet(v) 46 | } 47 | txn := s.tree.Txn() 48 | txn.Insert(s.toBytes(v), v) 49 | s.tree = txn.CommitOnly() // As Set is passed by value we can just modify it. 50 | return s 51 | } 52 | 53 | // Delete returns a new set without the value. The original 54 | // set is unchanged. 55 | func (s Set[T]) Delete(v T) Set[T] { 56 | if s.tree == nil { 57 | return s 58 | } 59 | txn := s.tree.Txn() 60 | txn.Delete(s.toBytes(v)) 61 | s.tree = txn.CommitOnly() 62 | return s 63 | } 64 | 65 | // Has returns true if the set has the value. 66 | func (s Set[T]) Has(v T) bool { 67 | if s.tree == nil { 68 | return false 69 | } 70 | _, _, found := s.tree.Get(s.toBytes(v)) 71 | return found 72 | } 73 | 74 | // All returns an iterator for all values. 
75 | func (s Set[T]) All() iter.Seq[T] { 76 | if s.tree == nil { 77 | return toSeq[T](nil) 78 | } 79 | return toSeq(s.tree.Iterator()) 80 | } 81 | 82 | // Union returns a set that is the union of the values 83 | // in the input sets. 84 | func (s Set[T]) Union(s2 Set[T]) Set[T] { 85 | if s2.tree == nil { 86 | return s 87 | } 88 | if s.tree == nil { 89 | return s2 90 | } 91 | txn := s.tree.Txn() 92 | iter := s2.tree.Iterator() 93 | for k, v, ok := iter.Next(); ok; k, v, ok = iter.Next() { 94 | txn.Insert(k, v) 95 | } 96 | s.tree = txn.CommitOnly() 97 | return s 98 | } 99 | 100 | // Difference returns a set with values that only 101 | // appear in the first set. 102 | func (s Set[T]) Difference(s2 Set[T]) Set[T] { 103 | if s.tree == nil || s2.tree == nil { 104 | return s 105 | } 106 | 107 | txn := s.tree.Txn() 108 | iter := s2.tree.Iterator() 109 | for k, _, ok := iter.Next(); ok; k, _, ok = iter.Next() { 110 | txn.Delete(k) 111 | } 112 | s.tree = txn.CommitOnly() 113 | return s 114 | } 115 | 116 | // Len returns the number of values in the set. 117 | func (s Set[T]) Len() int { 118 | if s.tree == nil { 119 | return 0 120 | } 121 | return s.tree.size 122 | } 123 | 124 | // Equal returns true if the two sets contain the equal keys. 125 | func (s Set[T]) Equal(other Set[T]) bool { 126 | switch { 127 | case s.tree == nil && other.tree == nil: 128 | return true 129 | case s.Len() != other.Len(): 130 | return false 131 | default: 132 | iter1 := s.tree.Iterator() 133 | iter2 := other.tree.Iterator() 134 | for { 135 | k1, _, ok := iter1.Next() 136 | if !ok { 137 | break 138 | } 139 | k2, _, _ := iter2.Next() 140 | // Equal lengths, no need to check 'ok' for 'iter2'. 141 | if !bytes.Equal(k1, k2) { 142 | return false 143 | } 144 | } 145 | return true 146 | } 147 | } 148 | 149 | // ToBytesFunc returns the function to extract the key from 150 | // the element type. Useful for utilities that are interested 151 | // in the key. 
152 | func (s Set[T]) ToBytesFunc() func(T) []byte { 153 | return s.toBytes 154 | } 155 | 156 | func (s Set[T]) MarshalJSON() ([]byte, error) { 157 | if s.tree == nil { 158 | return []byte("[]"), nil 159 | } 160 | var b bytes.Buffer 161 | b.WriteRune('[') 162 | iter := s.tree.Iterator() 163 | _, v, ok := iter.Next() 164 | for ok { 165 | bs, err := json.Marshal(v) 166 | if err != nil { 167 | return nil, err 168 | } 169 | b.Write(bs) 170 | _, v, ok = iter.Next() 171 | if ok { 172 | b.WriteRune(',') 173 | } 174 | } 175 | b.WriteRune(']') 176 | return b.Bytes(), nil 177 | } 178 | 179 | func (s *Set[T]) UnmarshalJSON(data []byte) error { 180 | dec := json.NewDecoder(bytes.NewReader(data)) 181 | t, err := dec.Token() 182 | if err != nil { 183 | return err 184 | } 185 | if d, ok := t.(json.Delim); !ok || d != '[' { 186 | return fmt.Errorf("%T.UnmarshalJSON: expected '[' got %v", s, t) 187 | } 188 | 189 | if s.tree == nil { 190 | *s = NewSet[T]() 191 | } 192 | txn := s.tree.Txn() 193 | 194 | for dec.More() { 195 | var x T 196 | err := dec.Decode(&x) 197 | if err != nil { 198 | return err 199 | } 200 | txn.Insert(s.toBytes(x), x) 201 | } 202 | s.tree = txn.CommitOnly() 203 | 204 | t, err = dec.Token() 205 | if err != nil { 206 | return err 207 | } 208 | if d, ok := t.(json.Delim); !ok || d != ']' { 209 | return fmt.Errorf("%T.UnmarshalJSON: expected ']' got %v", s, t) 210 | } 211 | return nil 212 | } 213 | 214 | func (s Set[T]) MarshalYAML() (any, error) { 215 | // TODO: Once yaml.v3 supports iter.Seq, drop the Collect(). 
216 | return slices.Collect(s.All()), nil 217 | } 218 | 219 | func (s *Set[T]) UnmarshalYAML(value *yaml.Node) error { 220 | if value.Kind != yaml.SequenceNode { 221 | return fmt.Errorf("%T.UnmarshalYAML: expected sequence", s) 222 | } 223 | 224 | if s.tree == nil { 225 | *s = NewSet[T]() 226 | } 227 | txn := s.tree.Txn() 228 | 229 | for _, e := range value.Content { 230 | var v T 231 | if err := e.Decode(&v); err != nil { 232 | return err 233 | } 234 | txn.Insert(s.toBytes(v), v) 235 | } 236 | s.tree = txn.CommitOnly() 237 | return nil 238 | } 239 | 240 | func toSeq[T any](iter *Iterator[T]) iter.Seq[T] { 241 | return func(yield func(T) bool) { 242 | if iter == nil { 243 | return 244 | } 245 | iter = iter.Clone() 246 | for _, x, ok := iter.Next(); ok; _, x, ok = iter.Next() { 247 | if !yield(x) { 248 | break 249 | } 250 | } 251 | } 252 | } 253 | -------------------------------------------------------------------------------- /part/set_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package part_test 5 | 6 | import ( 7 | "encoding/json" 8 | "slices" 9 | "testing" 10 | 11 | "github.com/cilium/statedb/part" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | "gopkg.in/yaml.v3" 15 | ) 16 | 17 | func TestStringSet(t *testing.T) { 18 | var s part.Set[string] 19 | 20 | assert.False(t, s.Has("nothing"), "Has nothing") 21 | 22 | s = s.Set("foo") 23 | assert.True(t, s.Has("foo"), "Has foo") 24 | 25 | count := 0 26 | for v := range s.All() { 27 | assert.Equal(t, "foo", v) 28 | count++ 29 | } 30 | assert.Equal(t, 1, count) 31 | 32 | s2 := part.NewSet("bar") 33 | 34 | s3 := s.Union(s2) 35 | assert.False(t, s.Has("bar"), "s has no bar") 36 | assert.False(t, s2.Has("foo"), "s2 has no foo") 37 | assert.True(t, s3.Has("foo"), "s3 has foo") 38 | assert.True(t, s3.Has("bar"), "s3 has bar") 39 | 40 | values := 
slices.Collect(s3.All()) 41 | assert.ElementsMatch(t, []string{"foo", "bar"}, values) 42 | 43 | s4 := s3.Difference(s2) 44 | assert.False(t, s4.Has("bar"), "s4 has no bar") 45 | assert.True(t, s4.Has("foo"), "s4 has foo") 46 | 47 | assert.Equal(t, 2, s3.Len()) 48 | 49 | s5 := s3.Delete("foo") 50 | assert.True(t, s3.Has("foo"), "s3 has foo") 51 | assert.False(t, s5.Has("foo"), "s3 has no foo") 52 | 53 | // Deleting again does the same. 54 | s5 = s3.Delete("foo") 55 | assert.False(t, s5.Has("foo"), "s3 has no foo") 56 | 57 | assert.Equal(t, 2, s3.Len()) 58 | assert.Equal(t, 1, s5.Len()) 59 | } 60 | 61 | func TestSetJSON(t *testing.T) { 62 | s := part.NewSet("foo", "bar", "baz") 63 | 64 | bs, err := json.Marshal(s) 65 | require.NoError(t, err, "Marshal") 66 | 67 | var s2 part.Set[string] 68 | err = json.Unmarshal(bs, &s2) 69 | require.NoError(t, err, "Unmarshal") 70 | require.True(t, s.Equal(s2), "Equal") 71 | } 72 | 73 | func TestSetYAML(t *testing.T) { 74 | s := part.NewSet("foo", "bar", "baz") 75 | 76 | bs, err := yaml.Marshal(s) 77 | require.NoError(t, err, "Marshal") 78 | require.Equal(t, "- bar\n- baz\n- foo\n", string(bs)) 79 | 80 | var s2 part.Set[string] 81 | err = yaml.Unmarshal(bs, &s2) 82 | require.NoError(t, err, "Unmarshal") 83 | require.True(t, s.Equal(s2), "Equal") 84 | 85 | var empty part.Set[string] 86 | bs, err = yaml.Marshal(empty) 87 | require.NoError(t, err, "Unmarshal") 88 | require.Equal(t, "[]\n", string(bs)) 89 | require.True(t, s.Equal(s2), "Equal") 90 | } 91 | -------------------------------------------------------------------------------- /part/tree.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package part 5 | 6 | // Tree is a persistent (immutable) adaptive radix tree. It supports 7 | // map-like operations on values keyed by []byte and additionally 8 | // prefix searching and lower bound searching. 
Each node in the tree 9 | // has an associated channel that is closed when that node is mutated. 10 | // This allows watching any part of the tree (any prefix) for changes. 11 | type Tree[T any] struct { 12 | opts *options 13 | root *header[T] 14 | size int // the number of objects in the tree 15 | } 16 | 17 | // New constructs a new tree. 18 | func New[T any](opts ...Option) *Tree[T] { 19 | var o options 20 | for _, opt := range opts { 21 | opt(&o) 22 | } 23 | return &Tree[T]{ 24 | root: newNode4[T](), 25 | size: 0, 26 | opts: &o, 27 | } 28 | } 29 | 30 | type Option func(*options) 31 | 32 | // RootOnlyWatch sets the tree to only have a watch channel on the root 33 | // node. This improves the speed at the cost of having a much more coarse 34 | // grained notifications. 35 | func RootOnlyWatch(o *options) { o.rootOnlyWatch = true } 36 | 37 | // Txn constructs a new transaction against the tree. Transactions 38 | // enable efficient large mutations of the tree by caching cloned 39 | // nodes. 40 | func (t *Tree[T]) Txn() *Txn[T] { 41 | txn := &Txn[T]{ 42 | Tree: *t, 43 | watches: make(map[chan struct{}]struct{}), 44 | } 45 | return txn 46 | } 47 | 48 | // Len returns the number of objects in the tree. 49 | func (t *Tree[T]) Len() int { 50 | return t.size 51 | } 52 | 53 | // Get fetches the value associated with the given key. 54 | // Returns the value, a watch channel (which is closed on 55 | // modification to the key) and boolean which is true if 56 | // value was found. 57 | func (t *Tree[T]) Get(key []byte) (T, <-chan struct{}, bool) { 58 | value, watch, ok := search(t.root, key) 59 | if t.opts.rootOnlyWatch { 60 | watch = t.root.watch 61 | } 62 | return value, watch, ok 63 | } 64 | 65 | // Prefix returns an iterator for all objects that starts with the 66 | // given prefix, and a channel that closes when any objects matching 67 | // the given prefix are upserted or deleted. 
68 | func (t *Tree[T]) Prefix(prefix []byte) (*Iterator[T], <-chan struct{}) { 69 | iter, watch := prefixSearch(t.root, prefix) 70 | if t.opts.rootOnlyWatch { 71 | watch = t.root.watch 72 | } 73 | return iter, watch 74 | } 75 | 76 | // RootWatch returns a watch channel for the root of the tree. 77 | // Since this is the channel associated with the root, this closes 78 | // when there are any changes to the tree. 79 | func (t *Tree[T]) RootWatch() <-chan struct{} { 80 | return t.root.watch 81 | } 82 | 83 | // LowerBound returns an iterator for all keys that have a value 84 | // equal to or higher than 'key'. 85 | func (t *Tree[T]) LowerBound(key []byte) *Iterator[T] { 86 | return lowerbound(t.root, key) 87 | } 88 | 89 | // Insert inserts the key into the tree with the given value. 90 | // Returns the old value if it exists and a new tree. 91 | func (t *Tree[T]) Insert(key []byte, value T) (old T, hadOld bool, tree *Tree[T]) { 92 | txn := t.Txn() 93 | old, hadOld = txn.Insert(key, value) 94 | tree = txn.Commit() 95 | return 96 | } 97 | 98 | // Modify a value in the tree. If the key does not exist the modify 99 | // function is called with the zero value for T. It is up to the 100 | // caller to not mutate the value in-place and to return a clone. 101 | // Returns the old value if it exists. 102 | func (t *Tree[T]) Modify(key []byte, mod func(T) T) (old T, hadOld bool, tree *Tree[T]) { 103 | txn := t.Txn() 104 | old, hadOld = txn.Modify(key, mod) 105 | tree = txn.Commit() 106 | return 107 | } 108 | 109 | // Delete the given key from the tree. 110 | // Returns the old value if it exists and the new tree. 111 | func (t *Tree[T]) Delete(key []byte) (old T, hadOld bool, tree *Tree[T]) { 112 | txn := t.Txn() 113 | old, hadOld = txn.Delete(key) 114 | tree = txn.Commit() 115 | return 116 | } 117 | 118 | // Iterator returns an iterator for all objects. 
119 | func (t *Tree[T]) Iterator() *Iterator[T] { 120 | return newIterator[T](t.root) 121 | } 122 | 123 | // PrintTree to the standard output. For debugging. 124 | func (t *Tree[T]) PrintTree() { 125 | t.root.printTree(0) 126 | } 127 | 128 | type options struct { 129 | rootOnlyWatch bool 130 | } 131 | -------------------------------------------------------------------------------- /reconciler/benchmark/.gitignore: -------------------------------------------------------------------------------- 1 | benchmark 2 | -------------------------------------------------------------------------------- /reconciler/benchmark/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package main 5 | 6 | import ( 7 | "context" 8 | "flag" 9 | "fmt" 10 | "iter" 11 | "log" 12 | "log/slog" 13 | "os" 14 | "runtime" 15 | "runtime/pprof" 16 | "sync/atomic" 17 | "time" 18 | 19 | "github.com/cilium/hive" 20 | "github.com/cilium/hive/cell" 21 | "github.com/cilium/hive/job" 22 | "github.com/cilium/statedb" 23 | "github.com/cilium/statedb/index" 24 | "github.com/cilium/statedb/reconciler" 25 | "golang.org/x/time/rate" 26 | ) 27 | 28 | var logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ 29 | Level: slog.LevelError, 30 | })) 31 | 32 | var cpuprofile = flag.String("cpuprofile", "", "write cpu profile to `file`") 33 | var memprofile = flag.String("memprofile", "", "write memory profile to `file`") 34 | var numObjects = flag.Int("objects", 1000000, "number of objects to create") 35 | var batchSize = flag.Int("batchsize", 1000, "batch size for writes") 36 | var incrBatchSize = flag.Int("incrbatchsize", 1000, "maximum batch size for incremental reconciliation") 37 | var quiet = flag.Bool("quiet", false, "quiet output for CI") 38 | 39 | type testObject struct { 40 | id uint64 41 | status reconciler.Status 42 | } 43 | 44 | func (t *testObject) GetStatus() 
reconciler.Status { 45 | return t.status 46 | } 47 | 48 | func (t *testObject) SetStatus(status reconciler.Status) *testObject { 49 | t.status = status 50 | return t 51 | } 52 | 53 | func (t *testObject) Clone() *testObject { 54 | t2 := *t 55 | return &t2 56 | } 57 | 58 | type mockOps struct { 59 | numUpdates atomic.Int32 60 | } 61 | 62 | // Delete implements reconciler.Operations. 63 | func (mt *mockOps) Delete(ctx context.Context, txn statedb.ReadTxn, rev statedb.Revision, obj *testObject) error { 64 | return nil 65 | } 66 | 67 | // Prune implements reconciler.Operations. 68 | func (mt *mockOps) Prune(ctx context.Context, txn statedb.ReadTxn, objects iter.Seq2[*testObject, statedb.Revision]) error { 69 | return nil 70 | } 71 | 72 | // Update implements reconciler.Operations. 73 | func (mt *mockOps) Update(ctx context.Context, txn statedb.ReadTxn, rev statedb.Revision, obj *testObject) error { 74 | mt.numUpdates.Add(1) 75 | return nil 76 | } 77 | 78 | var _ reconciler.Operations[*testObject] = &mockOps{} 79 | 80 | var idIndex = statedb.Index[*testObject, uint64]{ 81 | Name: "id", 82 | FromObject: func(t *testObject) index.KeySet { 83 | return index.NewKeySet(index.Uint64(t.id)) 84 | }, 85 | FromKey: index.Uint64, 86 | Unique: true, 87 | } 88 | 89 | func main() { 90 | var memBefore runtime.MemStats 91 | runtime.ReadMemStats(&memBefore) 92 | 93 | flag.Parse() 94 | if *cpuprofile != "" { 95 | f, err := os.Create(*cpuprofile) 96 | if err != nil { 97 | log.Fatal("could not create CPU profile: ", err) 98 | } 99 | defer f.Close() // error handling omitted for example 100 | if err := pprof.StartCPUProfile(f); err != nil { 101 | log.Fatal("could not start CPU profile: ", err) 102 | } 103 | defer pprof.StopCPUProfile() 104 | } 105 | 106 | var ( 107 | mt = &mockOps{} 108 | db *statedb.DB 109 | ) 110 | 111 | testObjects, err := statedb.NewTable("test-objects", idIndex) 112 | if err != nil { 113 | panic(err) 114 | } 115 | 116 | hive := hive.New( 117 | cell.SimpleHealthCell, 
118 | statedb.Cell, 119 | job.Cell, 120 | 121 | cell.Module( 122 | "test", 123 | "Test", 124 | 125 | cell.Invoke(func(db_ *statedb.DB) error { 126 | db = db_ 127 | return db.RegisterTable(testObjects) 128 | }), 129 | cell.Provide( 130 | func() (*mockOps, reconciler.Operations[*testObject]) { 131 | return mt, mt 132 | }, 133 | ), 134 | cell.Invoke(func(params reconciler.Params) error { 135 | _, err := reconciler.Register( 136 | params, 137 | 138 | testObjects, 139 | (*testObject).Clone, 140 | (*testObject).SetStatus, 141 | (*testObject).GetStatus, 142 | mt, 143 | nil, 144 | 145 | reconciler.WithRoundLimits( 146 | *incrBatchSize, 147 | rate.NewLimiter(1000.0, 10), 148 | ), 149 | ) 150 | return err 151 | }), 152 | ), 153 | ) 154 | 155 | err = hive.Start(logger, context.TODO()) 156 | if err != nil { 157 | panic(err) 158 | } 159 | 160 | start := time.Now() 161 | 162 | // Create objects in batches to allow the reconciler to start working 163 | // on them while they're added. 164 | id := uint64(0) 165 | batches := int(*numObjects / *batchSize) 166 | for b := 0; b < batches; b++ { 167 | if !*quiet { 168 | fmt.Printf("\rInserting batch %d/%d ...", b+1, batches) 169 | } 170 | wtxn := db.WriteTxn(testObjects) 171 | for j := 0; j < *batchSize; j++ { 172 | testObjects.Insert(wtxn, &testObject{ 173 | id: id, 174 | status: reconciler.StatusPending(), 175 | }) 176 | id++ 177 | } 178 | wtxn.Commit() 179 | } 180 | 181 | if !*quiet { 182 | fmt.Printf("\nWaiting for reconciliation to finish ...\n\n") 183 | } 184 | 185 | // Wait for all to be reconciled by waiting for the last added objects to be marked 186 | // reconciled. This only works here since none of the operations fail. 
187 | for { 188 | obj, _, watch, ok := testObjects.GetWatch(db.ReadTxn(), idIndex.Query(id-1)) 189 | if ok && obj.status.Kind == reconciler.StatusKindDone { 190 | break 191 | } 192 | <-watch 193 | } 194 | 195 | end := time.Now() 196 | duration := end.Sub(start) 197 | 198 | timePerObject := float64(duration) / float64(*numObjects) 199 | objsPerSecond := float64(time.Second) / timePerObject 200 | 201 | // Check that all objects were updated. 202 | if mt.numUpdates.Load() != int32(*numObjects) { 203 | log.Fatalf("expected %d updates, but only saw %d", *numObjects, mt.numUpdates.Load()) 204 | } 205 | 206 | // Check that all statuses are correctly set. 207 | for obj := range testObjects.All(db.ReadTxn()) { 208 | if obj.status.Kind != reconciler.StatusKindDone { 209 | log.Fatalf("Object with unexpected status: %#v", obj) 210 | } 211 | } 212 | 213 | if *memprofile != "" { 214 | f, err := os.Create(*memprofile) 215 | if err != nil { 216 | log.Fatal("could not create memory profile: ", err) 217 | } 218 | defer f.Close() // error handling omitted for example 219 | runtime.GC() // get up-to-date statistics 220 | if err := pprof.WriteHeapProfile(f); err != nil { 221 | log.Fatal("could not write memory profile: ", err) 222 | } 223 | } 224 | 225 | runtime.GC() 226 | var memAfter runtime.MemStats 227 | runtime.ReadMemStats(&memAfter) 228 | 229 | err = hive.Stop(logger, context.TODO()) 230 | if err != nil { 231 | panic(err) 232 | } 233 | 234 | fmt.Printf("%d objects reconciled in %.2f seconds (batch size %d)\n", 235 | *numObjects, float64(duration)/float64(time.Second), *batchSize) 236 | fmt.Printf("Throughput %.2f objects per second\n", objsPerSecond) 237 | fmt.Printf("Allocated %d objects, %dkB bytes, %dkB bytes still in use\n", 238 | memAfter.HeapObjects-memBefore.HeapObjects, 239 | (memAfter.HeapAlloc-memBefore.HeapAlloc)/1024, 240 | (memAfter.HeapInuse-memBefore.HeapInuse)/1024) 241 | 242 | } 243 | 
-------------------------------------------------------------------------------- /reconciler/benchmark/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eu 3 | go build . 4 | 5 | for batchSize in 100 1000 10000; do 6 | for incrBatchSize in 1000 5000; do 7 | echo "batchSize: $batchSize, incrBatchSize: $incrBatchSize" 8 | go run . -objects=100000 -batchsize=$batchSize -incrbatchsize=$incrBatchSize 9 | echo "----------------------------------------------------" 10 | done 11 | done 12 | -------------------------------------------------------------------------------- /reconciler/builder.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package reconciler 5 | 6 | import ( 7 | "time" 8 | 9 | "github.com/cilium/hive/job" 10 | "github.com/cilium/statedb" 11 | "github.com/cilium/statedb/index" 12 | "golang.org/x/time/rate" 13 | ) 14 | 15 | // Register creates a new reconciler and registers it to the application lifecycle. 16 | // 17 | // The setStatus etc. functions are passed in as arguments rather than requiring 18 | // the object to implement them via interface as this allows constructing multiple 19 | // reconcilers for a single object by having multiple status fields and different 20 | // functions for manipulating them. 21 | func Register[Obj comparable]( 22 | // General dependencies of the reconciler. 23 | params Params, 24 | // The table to reconcile 25 | table statedb.RWTable[Obj], 26 | 27 | // Function for cloning the object. 28 | clone func(Obj) Obj, 29 | 30 | // Function for setting the status. 31 | setStatus func(Obj, Status) Obj, 32 | 33 | // Function for getting the status. 34 | getStatus func(Obj) Status, 35 | 36 | // Reconciliation operations 37 | ops Operations[Obj], 38 | 39 | // (Optional) batch operations. Set to nil if not available. 
40 | batchOps BatchOperations[Obj], 41 | 42 | // zero or more options to override defaults. 43 | options ...Option, 44 | ) (Reconciler[Obj], error) { 45 | cfg := config[Obj]{ 46 | Table: table, 47 | GetObjectStatus: getStatus, 48 | SetObjectStatus: setStatus, 49 | CloneObject: clone, 50 | Operations: ops, 51 | BatchOperations: batchOps, 52 | options: defaultOptions(), 53 | } 54 | for _, opt := range options { 55 | opt(&cfg.options) 56 | } 57 | 58 | if cfg.Metrics == nil { 59 | if params.DefaultMetrics == nil { 60 | cfg.Metrics = NewUnpublishedExpVarMetrics() 61 | } else { 62 | cfg.Metrics = params.DefaultMetrics 63 | } 64 | } 65 | 66 | if err := cfg.validate(); err != nil { 67 | return nil, err 68 | } 69 | 70 | idx := cfg.Table.PrimaryIndexer() 71 | objectToKey := func(o any) index.Key { 72 | return idx.ObjectToKey(o.(Obj)) 73 | } 74 | r := &reconciler[Obj]{ 75 | Params: params, 76 | config: cfg, 77 | retries: newRetries(cfg.RetryBackoffMinDuration, cfg.RetryBackoffMaxDuration, objectToKey), 78 | externalPruneTrigger: make(chan struct{}, 1), 79 | primaryIndexer: idx, 80 | } 81 | 82 | g := params.Jobs.NewGroup(params.Health, params.Lifecycle) 83 | 84 | g.Add(job.OneShot("reconcile", r.reconcileLoop)) 85 | if r.config.RefreshInterval > 0 { 86 | g.Add(job.OneShot("refresh", r.refreshLoop)) 87 | } 88 | return r, nil 89 | } 90 | 91 | // Option for the reconciler 92 | type Option func(opts *options) 93 | 94 | // WithMetrics sets the [Metrics] instance to use with this reconciler. 95 | // The metrics capture the duration of operations during incremental and 96 | // full reconcilation and the errors that occur during either. 97 | // 98 | // If this option is not used, then the default metrics instance is used. 
// WithMetrics sets the [Metrics] instance to use with this reconciler.
// The metrics capture the duration of operations during incremental and
// full reconciliation and the errors that occur during either.
//
// If this option is not used, then the default metrics instance is used.
func WithMetrics(m Metrics) Option {
	return func(opts *options) {
		opts.Metrics = m
	}
}

// WithPruning enables periodic pruning (calls to the Prune() operation).
// [interval] is the interval at which Prune() is called to prune
// unexpected objects in the target system.
// Prune() will not be called before the table has been fully initialized
// (Initialized() returns true).
// A single Prune() can be forced via the [Reconciler.Prune] method regardless
// of whether pruning has been enabled.
//
// Pruning is enabled by default. See [config.go] for the default interval.
func WithPruning(interval time.Duration) Option {
	return func(opts *options) {
		opts.PruneInterval = interval
	}
}

// WithoutPruning disables periodic pruning by setting the prune
// interval to zero.
func WithoutPruning() Option {
	return func(opts *options) {
		opts.PruneInterval = 0
	}
}

// WithRefreshing enables periodic refreshes of objects.
// [interval] is the interval at which the objects are refreshed,
// e.g. how often Update() should be called to refresh an object even
// when it has not changed. This is implemented by periodically setting
// all objects that have not been updated for [RefreshInterval] or longer
// as pending. An interval of 0 disables refreshing.
// [limiter] is the rate-limiter for controlling the rate at which the
// objects are marked pending.
//
// Refreshing is enabled by default: the default options refresh objects
// every 30 minutes at a rate of 100 objects per second. Use this option
// to change the interval and the rate.
func WithRefreshing(interval time.Duration, limiter *rate.Limiter) Option {
	return func(opts *options) {
		opts.RefreshInterval = interval
		opts.RefreshRateLimiter = limiter
	}
}
// defaultOptions returns the option values that Register starts from
// before applying the options given by the caller.
func defaultOptions() options {
	return options{
		Metrics: nil, // left nil so that Register falls back to the default metrics

		// Refresh objects every 30 minutes at a rate of 100 per second.
		RefreshInterval:    30 * time.Minute,
		RefreshRateLimiter: rate.NewLimiter(100.0, 1),

		// Prune when initialized and then once an hour.
		PruneInterval: time.Hour,

		// Retry failed operations with exponential backoff from 100ms to 1min.
		RetryBackoffMinDuration: 100 * time.Millisecond,
		RetryBackoffMaxDuration: time.Minute,

		// Reconcile at most 1000 rounds per second with up to 1000 objects
		// per round, yielding a maximum rate of 1M objects per second.
		// NOTE(review): confirm that 1000 rounds per second is the intended
		// limit; a limit of 100 rounds per second (100k objects per second)
		// would need rate.NewLimiter(100.0, 1).
		IncrementalRoundSize: 1000,
		RateLimiter:          rate.NewLimiter(1000.0, 1),
	}
}
63 | PruneInterval time.Duration 64 | 65 | // RetryBackoffMinDuration is the minimum amount of time to wait before 66 | // retrying a failed Update() or Delete() operation on an object. 67 | // The retry wait time for an object will increase exponentially on 68 | // subsequent failures until RetryBackoffMaxDuration is reached. 69 | RetryBackoffMinDuration time.Duration 70 | 71 | // RetryBackoffMaxDuration is the maximum amount of time to wait before 72 | // retrying. 73 | RetryBackoffMaxDuration time.Duration 74 | 75 | // IncrementalRoundSize is the maximum number objects to reconcile during 76 | // incremental reconciliation before updating status and refreshing the 77 | // statedb snapshot. This should be tuned based on the cost of each operation 78 | // and the rate of expected changes so that health and per-object status 79 | // updates are not delayed too much. If in doubt, use a value between 100-1000. 80 | IncrementalRoundSize int 81 | 82 | // RateLimiter is optional and if set will use the limiter to wait between 83 | // reconciliation rounds. This allows trading latency with throughput by 84 | // waiting longer to collect a batch of objects to reconcile. 85 | RateLimiter *rate.Limiter 86 | } 87 | 88 | type config[Obj any] struct { 89 | // Table to reconcile. Mandatory. 90 | Table statedb.RWTable[Obj] 91 | 92 | // GetObjectStatus returns the reconciliation status for the object. 93 | // Mandatory. 94 | GetObjectStatus func(Obj) Status 95 | 96 | // SetObjectStatus sets the reconciliation status for the object. 97 | // This is called with a copy of the object returned by CloneObject. 98 | // Mandatory. 99 | SetObjectStatus func(Obj, Status) Obj 100 | 101 | // CloneObject returns a shallow copy of the object. 
This is used to 102 | // make it possible for the reconciliation operations to mutate 103 | // the object (to for example provide additional information that the 104 | // reconciliation produces) and to be able to set the reconciliation 105 | // status after the reconciliation. 106 | // Mandatory. 107 | CloneObject func(Obj) Obj 108 | 109 | // Operations defines how an object is reconciled. Mandatory. 110 | Operations Operations[Obj] 111 | 112 | // BatchOperations is optional and if provided these are used instead of 113 | // the single-object operations. 114 | BatchOperations BatchOperations[Obj] 115 | 116 | options 117 | } 118 | 119 | func (cfg config[Obj]) validate() error { 120 | if cfg.Table == nil { 121 | return fmt.Errorf("%T.Table cannot be nil", cfg) 122 | } 123 | if cfg.GetObjectStatus == nil { 124 | return fmt.Errorf("%T.GetObjectStatus cannot be nil", cfg) 125 | } 126 | if cfg.SetObjectStatus == nil { 127 | return fmt.Errorf("%T.SetObjectStatus cannot be nil", cfg) 128 | } 129 | if cfg.CloneObject == nil { 130 | return fmt.Errorf("%T.CloneObject cannot be nil", cfg) 131 | } 132 | if cfg.IncrementalRoundSize <= 0 { 133 | return fmt.Errorf("%T.IncrementalBatchSize needs to be >0", cfg) 134 | } 135 | if cfg.RefreshInterval < 0 { 136 | return fmt.Errorf("%T.RefreshInterval must be >=0", cfg) 137 | } 138 | if cfg.PruneInterval < 0 { 139 | return fmt.Errorf("%T.PruneInterval must be >=0", cfg) 140 | } 141 | if cfg.RetryBackoffMaxDuration <= 0 { 142 | return fmt.Errorf("%T.RetryBackoffMaxDuration must be >0", cfg) 143 | } 144 | if cfg.RetryBackoffMinDuration <= 0 { 145 | return fmt.Errorf("%T.RetryBackoffMinDuration must be >0", cfg) 146 | } 147 | if cfg.Operations == nil { 148 | return fmt.Errorf("%T.Operations must be defined", cfg) 149 | } 150 | return nil 151 | } 152 | -------------------------------------------------------------------------------- /reconciler/example/.gitignore: 
-------------------------------------------------------------------------------- 1 | example 2 | -------------------------------------------------------------------------------- /reconciler/example/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package main 5 | 6 | import ( 7 | "expvar" 8 | "fmt" 9 | "io" 10 | "log/slog" 11 | "net/http" 12 | "os" 13 | "strings" 14 | "time" 15 | 16 | "github.com/spf13/cobra" 17 | "golang.org/x/time/rate" 18 | 19 | "github.com/cilium/hive" 20 | "github.com/cilium/hive/cell" 21 | "github.com/cilium/hive/job" 22 | "github.com/cilium/statedb" 23 | "github.com/cilium/statedb/reconciler" 24 | ) 25 | 26 | // This is a simple example of the statedb reconciler. It implements an 27 | // HTTP API for creating and deleting "memos" that are stored on the 28 | // disk. 29 | // 30 | // To run the application: 31 | // 32 | // $ go run . 33 | // (ctrl-c to stop) 34 | // 35 | // To create a memo: 36 | // 37 | // $ curl -d 'hello world' http://localhost:8080/memos/greeting 38 | // $ cat memos/greeting 39 | // 40 | // To delete a memo: 41 | // 42 | // $ curl -XDELETE http://localhost:8080/memos/greeting 43 | // 44 | // The application builds on top of the reconciler which retries any failed 45 | // operations and also does periodic "full reconciliation" to prune unknown 46 | // memos and check that the stored memos are up-to-date. 
To test the resilience 47 | // you can try out the following: 48 | // 49 | // # Create 'memos/greeting' 50 | // $ curl -d 'hello world' http://localhost:8080/memos/greeting 51 | // 52 | // # Make the file read-only and try changing it: 53 | // $ chmod a-w memos/greeting 54 | // $ curl -d 'hei maailma' http://localhost:8080/memos/greeting 55 | // # (You should now see the update operation hitting permission denied) 56 | // 57 | // # The reconciliation state can be observed in the Table[*Memo]: 58 | // $ curl -s http://localhost:8080/statedb | jq . 59 | // 60 | // # Let's give write permissions back: 61 | // $ chmod u+w memos/greeting 62 | // # (The update operation should now succeed) 63 | // $ cat memos/greeting 64 | // $ curl -s http://localhost:8080/statedb | jq . 65 | // 66 | // # The full reconciliation runs once a minute. We can see it in 67 | // # action by modifying the contents of our greeting or by creating 68 | // # a file directly: 69 | // $ echo bogus > memos/bogus 70 | // $ echo mangled > memos/greeting 71 | // # (wait up to a minute) 72 | // $ cat memos/bogus 73 | // $ cat memos/greeting 74 | // 75 | 76 | func main() { 77 | cmd := cobra.Command{ 78 | Use: "example", 79 | Run: func(_ *cobra.Command, args []string) { 80 | if err := Hive.Run(slog.Default()); err != nil { 81 | fmt.Fprintf(os.Stderr, "Run: %s\n", err) 82 | } 83 | }, 84 | } 85 | 86 | // Register command-line flags. Currently only 87 | // has --directory for specifying where to store 88 | // the memos. 89 | Hive.RegisterFlags(cmd.Flags()) 90 | 91 | // Add the "hive" command for inspecting the object graph: 92 | // 93 | // $ go run . hive 94 | // 95 | cmd.AddCommand(Hive.Command()) 96 | 97 | cmd.Execute() 98 | } 99 | 100 | var Hive = hive.NewWithOptions( 101 | hive.Options{ 102 | // Create a named DB handle for each module. 
103 | ModuleDecorators: []cell.ModuleDecorator{ 104 | func(db *statedb.DB, id cell.ModuleID) *statedb.DB { 105 | return db.NewHandle(string(id)) 106 | }, 107 | }, 108 | }, 109 | 110 | statedb.Cell, 111 | job.Cell, 112 | 113 | cell.SimpleHealthCell, 114 | 115 | cell.Provide(reconciler.NewExpVarMetrics), 116 | 117 | cell.Module( 118 | "example", 119 | "Reconciler example", 120 | 121 | cell.Config(Config{}), 122 | 123 | cell.Provide( 124 | // Create and register the RWTable[*Memo] 125 | NewMemoTable, 126 | 127 | // Provide the Operations[*Memo] for reconciling Memos. 128 | NewMemoOps, 129 | ), 130 | 131 | // Create and register the reconciler for memos. 132 | // The reconciler watches Table[*Memo] for changes and 133 | // updates the memo files on disk accordingly. 134 | cell.Invoke(registerMemoReconciler), 135 | 136 | cell.Invoke(registerHTTPServer), 137 | ), 138 | ) 139 | 140 | func registerMemoReconciler( 141 | params reconciler.Params, 142 | ops reconciler.Operations[*Memo], 143 | tbl statedb.RWTable[*Memo], 144 | m *reconciler.ExpVarMetrics) error { 145 | 146 | // Create a new reconciler and register it to the lifecycle. 147 | // We ignore the returned Reconciler[*Memo] as we don't use it. 148 | _, err := reconciler.Register( 149 | params, 150 | tbl, 151 | (*Memo).Clone, 152 | (*Memo).SetStatus, 153 | (*Memo).GetStatus, 154 | ops, 155 | nil, // no batch operations support 156 | 157 | reconciler.WithMetrics(m), 158 | // Prune unexpected memos from disk once a minute. 159 | reconciler.WithPruning(time.Minute), 160 | // Refresh the memos once a minute. 
161 | reconciler.WithRefreshing(time.Minute, rate.NewLimiter(100.0, 1)), 162 | ) 163 | return err 164 | } 165 | 166 | func registerHTTPServer( 167 | lc cell.Lifecycle, 168 | log *slog.Logger, 169 | db *statedb.DB, 170 | memos statedb.RWTable[*Memo]) { 171 | 172 | mux := http.NewServeMux() 173 | 174 | // To dump the metrics: 175 | // curl -s http://localhost:8080/expvar 176 | mux.Handle("/expvar", expvar.Handler()) 177 | 178 | // For dumping the database: 179 | // curl -s http://localhost:8080/statedb | jq . 180 | mux.HandleFunc("/statedb", func(w http.ResponseWriter, r *http.Request) { 181 | w.Header().Add("Content-Type", "application/json") 182 | w.WriteHeader(http.StatusOK) 183 | if err := db.ReadTxn().WriteJSON(w); err != nil { 184 | panic(err) 185 | } 186 | }) 187 | 188 | // For creating and deleting memos: 189 | // curl -d 'foo' http://localhost:8080/memos/bar 190 | // curl -XDELETE http://localhost:8080/memos/bar 191 | mux.HandleFunc("/memos/", func(w http.ResponseWriter, r *http.Request) { 192 | name, ok := strings.CutPrefix(r.URL.Path, "/memos/") 193 | if !ok { 194 | w.WriteHeader(http.StatusBadRequest) 195 | return 196 | } 197 | 198 | txn := db.WriteTxn(memos) 199 | defer txn.Commit() 200 | 201 | switch r.Method { 202 | case "POST": 203 | content, err := io.ReadAll(r.Body) 204 | if err != nil { 205 | return 206 | } 207 | memos.Insert( 208 | txn, 209 | &Memo{ 210 | Name: name, 211 | Content: string(content), 212 | Status: reconciler.StatusPending(), 213 | }) 214 | log.Info("Inserted memo", "name", name) 215 | w.WriteHeader(http.StatusOK) 216 | 217 | case "DELETE": 218 | memo, _, ok := memos.Get(txn, MemoNameIndex.Query(name)) 219 | if !ok { 220 | w.WriteHeader(http.StatusNotFound) 221 | return 222 | } 223 | memos.Delete(txn, memo) 224 | log.Info("Deleted memo", "name", name) 225 | w.WriteHeader(http.StatusOK) 226 | } 227 | }) 228 | 229 | server := http.Server{ 230 | Addr: "127.0.0.1:8080", 231 | Handler: mux, 232 | } 233 | 234 | lc.Append(cell.Hook{ 235 | 
OnStart: func(cell.HookContext) error { 236 | log.Info("Serving API", "address", server.Addr) 237 | go server.ListenAndServe() 238 | return nil 239 | }, 240 | OnStop: func(ctx cell.HookContext) error { 241 | return server.Shutdown(ctx) 242 | }, 243 | }) 244 | 245 | } 246 | -------------------------------------------------------------------------------- /reconciler/example/ops.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package main 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "iter" 10 | "log/slog" 11 | "os" 12 | "path" 13 | 14 | "github.com/cilium/hive/cell" 15 | "github.com/cilium/statedb" 16 | "github.com/cilium/statedb/reconciler" 17 | ) 18 | 19 | // MemoOps writes [Memo]s to disk. 20 | // Implements the Reconciler.Operations[*Memo] API. 21 | type MemoOps struct { 22 | log *slog.Logger 23 | directory string 24 | } 25 | 26 | // NewMemoOps creates the memo operations. 27 | func NewMemoOps(lc cell.Lifecycle, log *slog.Logger, cfg Config) reconciler.Operations[*Memo] { 28 | ops := &MemoOps{directory: cfg.Directory, log: log} 29 | 30 | // Register the Start and Stop methods to be called when the application 31 | // starts and stops respectively. The start hook will create the 32 | // memo directory. 33 | lc.Append(ops) 34 | return ops 35 | } 36 | 37 | // Delete a memo. 38 | func (ops *MemoOps) Delete(ctx context.Context, txn statedb.ReadTxn, rev statedb.Revision, memo *Memo) error { 39 | filename := path.Join(ops.directory, memo.Name) 40 | err := os.Remove(filename) 41 | ops.log.Info("Delete", "filename", filename, "error", err) 42 | return err 43 | } 44 | 45 | // Prune unexpected memos. 
46 | func (ops *MemoOps) Prune(ctx context.Context, txn statedb.ReadTxn, objects iter.Seq2[*Memo, statedb.Revision]) error { 47 | expected := map[string]struct{}{} 48 | 49 | for memo := range objects { 50 | expected[memo.Name] = struct{}{} 51 | } 52 | 53 | // Find unexpected files 54 | unexpected := map[string]struct{}{} 55 | if entries, err := os.ReadDir(ops.directory); err != nil { 56 | return err 57 | } else { 58 | for _, entry := range entries { 59 | if _, ok := expected[entry.Name()]; !ok { 60 | unexpected[entry.Name()] = struct{}{} 61 | } 62 | } 63 | } 64 | 65 | // ... and remove them. 66 | var errs []error 67 | for name := range unexpected { 68 | filename := path.Join(ops.directory, name) 69 | err := os.Remove(filename) 70 | ops.log.Info("Prune", "filename", filename, "error", err) 71 | if err != nil { 72 | errs = append(errs, err) 73 | } 74 | } 75 | return errors.Join(errs...) 76 | } 77 | 78 | // Update a memo. 79 | func (ops *MemoOps) Update(ctx context.Context, txn statedb.ReadTxn, rev statedb.Revision, memo *Memo) error { 80 | filename := path.Join(ops.directory, memo.Name) 81 | err := os.WriteFile(filename, []byte(memo.Content), 0644) 82 | ops.log.Info("Update", "filename", filename, "error", err) 83 | return err 84 | } 85 | 86 | var _ reconciler.Operations[*Memo] = &MemoOps{} 87 | 88 | func (ops *MemoOps) Start(cell.HookContext) error { 89 | return os.MkdirAll(ops.directory, 0755) 90 | } 91 | 92 | func (*MemoOps) Stop(cell.HookContext) error { 93 | return nil 94 | } 95 | 96 | var _ cell.HookInterface = &MemoOps{} 97 | -------------------------------------------------------------------------------- /reconciler/example/types.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package main 5 | 6 | import ( 7 | "github.com/spf13/pflag" 8 | 9 | "github.com/cilium/statedb" 10 | "github.com/cilium/statedb/index" 11 | 
"github.com/cilium/statedb/reconciler" 12 | ) 13 | 14 | // Config defines the command-line configuration for the memos 15 | // example application. 16 | type Config struct { 17 | Directory string // the directory in which memos are stored. 18 | } 19 | 20 | func (def Config) Flags(flags *pflag.FlagSet) { 21 | flags.String("directory", "memos", "Memo directory") 22 | } 23 | 24 | // Memo is a brief note stored in the memos directory. A memo 25 | // can be created with the /memos API. 26 | type Memo struct { 27 | Name string // filename of the memo. Stored in /. 28 | Content string // contents of the memo. 29 | Status reconciler.Status // reconciliation status 30 | } 31 | 32 | // GetStatus returns the reconciliation status. Used to provide the 33 | // reconciler access to it. 34 | func (memo *Memo) GetStatus() reconciler.Status { 35 | return memo.Status 36 | } 37 | 38 | // SetStatus sets the reconciliation status. 39 | // Used by the reconciler to update the reconciliation status of the memo. 40 | func (memo *Memo) SetStatus(newStatus reconciler.Status) *Memo { 41 | memo.Status = newStatus 42 | return memo 43 | } 44 | 45 | // Clone returns a shallow copy of the memo. 46 | func (memo *Memo) Clone() *Memo { 47 | m := *memo 48 | return &m 49 | } 50 | 51 | // MemoNameIndex allows looking up the memo by its name, e.g. 52 | // memos.First(txn, MemoNameIndex.Query("my-memo")) 53 | var MemoNameIndex = statedb.Index[*Memo, string]{ 54 | Name: "name", 55 | FromObject: func(memo *Memo) index.KeySet { 56 | return index.NewKeySet(index.String(memo.Name)) 57 | }, 58 | FromKey: index.String, 59 | Unique: true, 60 | } 61 | 62 | // MemoStatusIndex indexes memos by their reconciliation status. 63 | // This is mainly used by the reconciler to implement WaitForReconciliation. 64 | var MemoStatusIndex = reconciler.NewStatusIndex((*Memo).GetStatus) 65 | 66 | // NewMemoTable creates and registers the memos table. 
const (
	// maxJoinedErrors limits the number of errors to join and return from
	// failed reconciliation. This avoids constructing a massive error for
	// health status when many operations fail at once.
	maxJoinedErrors = 10
)

// omittedError returns the placeholder error that stands in for the n
// errors left out by joinErrors.
func omittedError(n int) error {
	return fmt.Errorf("%d further errors omitted", n)
}

// joinErrors joins the given errors into a single error. At most
// maxJoinedErrors of them are kept; the rest is replaced by one trailing
// error stating how many were omitted. Returns nil when there are no
// (non-nil) errors. The slice passed in is never modified.
func joinErrors(errs []error) error {
	if len(errs) > maxJoinedErrors {
		omitted := len(errs) - maxJoinedErrors
		// Copy only the part that is kept. The input may be very large
		// when many operations fail at once, and appending to a sub-slice
		// of it directly would overwrite the caller's next element.
		errs = append(slices.Clone(errs[:maxJoinedErrors]), omittedError(omitted))
	}
	return errors.Join(errs...)
}
// This is optional and should be only used when there is a need to often check that all
// objects are fully reconciled that outweighs the cost of maintaining a status index.
// The index is named "status", is non-unique, and keys each object by the Kind of
// the status returned by getObjectStatus.
func NewStatusIndex[Obj any](getObjectStatus func(Obj) Status) statedb.Index[Obj, StatusKind] {
	return statedb.Index[Obj, StatusKind]{
		Name: "status",
		FromObject: func(obj Obj) index.KeySet {
			return index.NewKeySet(getObjectStatus(obj).Kind.Key())
		},
		FromKey: StatusKind.Key,
		Unique:  false,
	}
}

// WaitForReconciliation blocks until all objects have been reconciled or the context
// is cancelled.
//
// It returns nil once the status index holds no object with StatusKindPending
// and no object with StatusKindError, and ctx.Err() if the context is done first.
func WaitForReconciliation[Obj any](ctx context.Context, db *statedb.DB, table statedb.Table[Obj], statusIndex statedb.Index[Obj, StatusKind]) error {
	for {
		// A fresh snapshot is taken on every iteration.
		txn := db.ReadTxn()

		// See if there are any pending or error'd objects.
		_, _, watchPending, okPending := table.GetWatch(txn, statusIndex.Query(StatusKindPending))
		_, _, watchError, okError := table.GetWatch(txn, statusIndex.Query(StatusKindError))
		if !okPending && !okError {
			return nil
		}

		// Wait for updates before checking again.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-watchPending:
		case <-watchError:
		}
	}
}

--------------------------------------------------------------------------------
/reconciler/metrics.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package reconciler

import (
	"expvar"
	"time"

	"github.com/cilium/hive/cell"
)

// Metrics is the sink for reconciler measurements. All methods identify the
// reporting reconciler by its module ID.
type Metrics interface {
	// ReconciliationDuration reports the time spent on the given operation
	// (see OpUpdate and OpDelete).
	ReconciliationDuration(moduleID cell.FullModuleID, operation string, duration time.Duration)
	// ReconciliationErrors reports the number of new errors and the current
	// number of errors after a reconciliation round.
	ReconciliationErrors(moduleID cell.FullModuleID, new, current int)

	// PruneError reports the outcome of a prune; err is nil on success.
	PruneError(moduleID cell.FullModuleID, err error)
	// PruneDuration reports the time spent pruning.
	PruneDuration(moduleID cell.FullModuleID, duration time.Duration)
}

// Operation names passed to Metrics.ReconciliationDuration.
const (
	OpUpdate = "update"
	OpDelete = "delete"
)

// ExpVarMetrics implements Metrics on top of expvar maps keyed by the
// module ID (and, for durations, by "<module ID>/<operation>").
type ExpVarMetrics struct {
	// root holds the maps when they are not published globally
	// (see newExpVarMetrics).
	root *expvar.Map

	ReconciliationCountVar         *expvar.Map
	ReconciliationDurationVar      *expvar.Map
	ReconciliationTotalErrorsVar   *expvar.Map
	ReconciliationCurrentErrorsVar *expvar.Map

	PruneCountVar         *expvar.Map
	PruneDurationVar      *expvar.Map
	PruneTotalErrorsVar   *expvar.Map
	PruneCurrentErrorsVar *expvar.Map
}

// PruneDuration adds the duration, in seconds, to the module's accumulated
// prune duration.
func (m *ExpVarMetrics) PruneDuration(moduleID cell.FullModuleID, duration time.Duration) {
	m.PruneDurationVar.AddFloat(moduleID.String(), duration.Seconds())
}

// PruneError records one prune for the module. On a non-nil err the total
// error count is incremented and the current error count is set to 1;
// otherwise the current error count is set to 0.
func (m *ExpVarMetrics) PruneError(moduleID cell.FullModuleID, err error) {
	m.PruneCountVar.Add(moduleID.String(), 1)

	var intVar expvar.Int // zero value, i.e. "no current error"
	if err != nil {
		m.PruneTotalErrorsVar.Add(moduleID.String(), 1)
		intVar.Set(1)
	}
	m.PruneCurrentErrorsVar.Set(moduleID.String(), &intVar)
}

// ReconciliationDuration adds the duration, in seconds, to the accumulated
// duration stored under "<module ID>/<operation>".
func (m *ExpVarMetrics) ReconciliationDuration(moduleID cell.FullModuleID, operation string, duration time.Duration) {
	m.ReconciliationDurationVar.AddFloat(moduleID.String()+"/"+operation, duration.Seconds())
}

// ReconciliationErrors records one reconciliation round for the module: the
// round count is incremented, 'new' is added to the total error count and
// the current error count is replaced with 'current'.
func (m *ExpVarMetrics) ReconciliationErrors(moduleID cell.FullModuleID, new, current int) {
	m.ReconciliationCountVar.Add(moduleID.String(), 1)
	m.ReconciliationTotalErrorsVar.Add(moduleID.String(), int64(new))

	var intVar expvar.Int
	intVar.Set(int64(current))
	m.ReconciliationCurrentErrorsVar.Set(moduleID.String(), &intVar)
}

var _ Metrics = &ExpVarMetrics{}

// NewExpVarMetrics returns metrics whose maps are published globally with
// expvar.NewMap. Per the expvar documentation publishing an already
// registered name panics, so this is presumably meant to be called once per
// process.
func NewExpVarMetrics() *ExpVarMetrics {
	return newExpVarMetrics(true)
}

// NewUnpublishedExpVarMetrics returns metrics whose maps are not registered
// globally; they are reachable through Map instead.
func NewUnpublishedExpVarMetrics() *ExpVarMetrics {
	return newExpVarMetrics(false)
}

// Map returns the root map. It contains the metric maps only for unpublished
// metrics; for published metrics the root map is left empty.
func (m *ExpVarMetrics) Map() *expvar.Map {
	return m.root
}

// newExpVarMetrics constructs the metric maps, either publishing each one
// globally under its name or attaching it to a private root map.
func newExpVarMetrics(publish bool) *ExpVarMetrics {
	root := new(expvar.Map).Init()
	newMap := func(name string) *expvar.Map {
		if publish {
			return expvar.NewMap(name)
		}
		m := new(expvar.Map).Init()
		root.Set(name, m)
		return m
	}
	return &ExpVarMetrics{
		root:                           root,
		ReconciliationCountVar:         newMap("reconciliation_count"),
		ReconciliationDurationVar:      newMap("reconciliation_duration"),
		ReconciliationTotalErrorsVar:   newMap("reconciliation_total_errors"),
		ReconciliationCurrentErrorsVar: newMap("reconciliation_current_errors"),
		PruneCountVar:                  newMap("prune_count"),
		PruneDurationVar:               newMap("prune_duration"),
		PruneTotalErrorsVar:            newMap("prune_total_errors"),
		PruneCurrentErrorsVar:          newMap("prune_current_errors"),
	}
}

--------------------------------------------------------------------------------
/reconciler/multi_test.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package reconciler_test

import (
	"context"
	"errors"
	"iter"
	"log/slog"
"sync/atomic"
	"testing"
	"time"

	"github.com/cilium/hive"
	"github.com/cilium/hive/cell"
	"github.com/cilium/hive/hivetest"
	"github.com/cilium/hive/job"
	"github.com/cilium/statedb"
	"github.com/cilium/statedb/index"
	"github.com/cilium/statedb/reconciler"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// multiStatusObject is a test object that carries one reconciliation status
// per reconciler in a StatusSet.
type multiStatusObject struct {
	ID       uint64
	Statuses reconciler.StatusSet
}

// Clone returns a shallow copy of the object.
func (m *multiStatusObject) Clone() *multiStatusObject {
	m2 := *m
	return &m2
}

// multiStatusIndex is the unique index over multiStatusObject.ID.
var multiStatusIndex = statedb.Index[*multiStatusObject, uint64]{
	Name: "id",
	FromObject: func(t *multiStatusObject) index.KeySet {
		return index.NewKeySet(index.Uint64(t.ID))
	},
	FromKey: index.Uint64,
	Unique:  true,
}

// multiMockOps is a mock reconciliation target that counts updates and can
// be switched into a failing mode.
type multiMockOps struct {
	// numUpdates counts calls to Update.
	// NOTE(review): plain int with no synchronization of its own; the test
	// reads it from the test goroutine — presumably safe only because the
	// read happens after observing the committed status. Confirm with -race.
	numUpdates int
	// faulty makes Update return an error while set.
	faulty atomic.Bool
}

// Delete implements reconciler.Operations.
func (m *multiMockOps) Delete(context.Context, statedb.ReadTxn, statedb.Revision, *multiStatusObject) error {
	return nil
}

// Prune implements reconciler.Operations.
func (m *multiMockOps) Prune(context.Context, statedb.ReadTxn, iter.Seq2[*multiStatusObject, statedb.Revision]) error {
	return nil
}

// Update implements reconciler.Operations. It counts the call and fails
// with "fail" while the mock is marked faulty.
func (m *multiMockOps) Update(ctx context.Context, txn statedb.ReadTxn, rev statedb.Revision, obj *multiStatusObject) error {
	m.numUpdates++
	if m.faulty.Load() {
		return errors.New("fail")
	}
	return nil
}

var _ reconciler.Operations[*multiStatusObject] = &multiMockOps{}

// TestMultipleReconcilers tests use of multiple reconcilers against
// a single object.
func TestMultipleReconcilers(t *testing.T) {
	table, err := statedb.NewTable("objects", multiStatusIndex)
	require.NoError(t, err, "NewTable")

	var ops1, ops2 multiMockOps
	var db *statedb.DB

	// Two reconcilers ("test1" and "test2") are registered against the same
	// table; each stores its status under its own name in the StatusSet.
	hive := hive.New(
		statedb.Cell,
		job.Cell,
		cell.Provide(
			cell.NewSimpleHealth,
			reconciler.NewExpVarMetrics,
			func(r job.Registry, h cell.Health, lc cell.Lifecycle) job.Group {
				return r.NewGroup(h, lc)
			},
		),
		cell.Invoke(func(db_ *statedb.DB) error {
			db = db_
			return db.RegisterTable(table)
		}),

		cell.Module("test1", "First reconciler",
			cell.Invoke(func(params reconciler.Params) error {
				_, err := reconciler.Register(
					params,
					table,
					(*multiStatusObject).Clone,
					func(obj *multiStatusObject, s reconciler.Status) *multiStatusObject {
						obj.Statuses = obj.Statuses.Set("test1", s)
						return obj
					},
					func(obj *multiStatusObject) reconciler.Status {
						return obj.Statuses.Get("test1")
					},
					&ops1,
					nil,
					reconciler.WithRetry(time.Hour, time.Hour),
				)
				return err
			}),
		),

		cell.Module("test2", "Second reconciler",
			cell.Invoke(func(params reconciler.Params) error {
				_, err := reconciler.Register(
					params,
					table,
					(*multiStatusObject).Clone,
					func(obj *multiStatusObject, s reconciler.Status) *multiStatusObject {
						obj.Statuses = obj.Statuses.Set("test2", s)
						return obj
					},
					func(obj *multiStatusObject) reconciler.Status {
						return obj.Statuses.Get("test2")
					},
					&ops2,
					nil,
					reconciler.WithRetry(time.Hour, time.Hour),
				)
				return err
			}),
		),
	)

	log := hivetest.Logger(t, hivetest.LogLevel(slog.LevelError))
	require.NoError(t, hive.Start(log, context.TODO()), "Start")

	wtxn := db.WriteTxn(table)
	table.Insert(wtxn, &multiStatusObject{
		ID:       1,
		Statuses: reconciler.NewStatusSet(),
	})
	wtxn.Commit()

	// Wait until both reconcilers have marked the object as done.
	var obj1 *multiStatusObject
	for {
		obj, _, watch, found := table.GetWatch(db.ReadTxn(), multiStatusIndex.Query(1))
		if found &&
			obj.Statuses.Get("test1").Kind == reconciler.StatusKindDone &&
			obj.Statuses.Get("test2").Kind == reconciler.StatusKindDone {

			// Check that both reconcilers performed the update only once.
			assert.Equal(t, 1, ops1.numUpdates)
			assert.Equal(t, 1, ops2.numUpdates)
			assert.Regexp(t, "^Done: test[12] test[12] \\(.* ago\\)", obj.Statuses.String())

			obj1 = obj
			break
		}
		<-watch
	}

	// Make the second reconciler faulty.
	ops2.faulty.Store(true)

	// Mark the object pending again. Reuse the StatusSet.
	wtxn = db.WriteTxn(table)
	obj1 = obj1.Clone()
	obj1.Statuses = obj1.Statuses.Pending()
	assert.Regexp(t, "^Pending: test[12] test[12] \\(.* ago\\)", obj1.Statuses.String())
	table.Insert(wtxn, obj1)
	wtxn.Commit()

	// Wait for it to reconcile.
	for {
		obj, _, watch, found := table.GetWatch(db.ReadTxn(), multiStatusIndex.Query(1))
		if found &&
			obj.Statuses.Get("test1").Kind == reconciler.StatusKindDone &&
			obj.Statuses.Get("test2").Kind == reconciler.StatusKindError {

			assert.Equal(t, 2, ops1.numUpdates)
			assert.Equal(t, 2, ops2.numUpdates)
			assert.Regexp(t, "^Errored: test2 \\(fail\\), Done: test1 \\(.* ago\\)", obj.Statuses.String())

			break
		}
		<-watch
	}

	require.NoError(t, hive.Stop(log, context.TODO()), "Stop")
}

--------------------------------------------------------------------------------
/reconciler/reconciler.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package reconciler

import (
	"context"
	"fmt"
	"iter"
	"time"

	"github.com/cilium/hive/cell"
	"github.com/cilium/statedb"
)

// reconciler bundles the parameters, configuration, retry queue and prune
// trigger used by reconcileLoop and refreshLoop.
type reconciler[Obj comparable] struct {
	Params
	config               config[Obj]
	retries              *retries
	externalPruneTrigger chan struct{}
	primaryIndexer       statedb.Indexer[Obj]
}

// Prune requests a pruning round from the reconcile loop. The send is
// non-blocking: if the trigger cannot be delivered immediately the request
// is dropped (presumably because one is already pending — the channel's
// capacity is set where the reconciler is constructed).
func (r *reconciler[Obj]) Prune() {
	select {
	case r.externalPruneTrigger <- struct{}{}:
	default:
	}
}

// reconcileLoop is the main loop of the reconciler. On each trigger (table
// change, retry timer, table initialization, prune tick or external prune
// request) it performs incremental reconciliation, prunes if requested and
// the table has been initialized, and reports the result to 'health'.
//
// It returns when the context is cancelled, when the rate limiter fails or
// when the change iterator cannot be created.
func (r *reconciler[Obj]) reconcileLoop(ctx context.Context, health cell.Health) error {
	// A nil ticker channel never fires, which disables periodic pruning
	// when no positive PruneInterval is configured.
	var pruneTickerChan <-chan time.Time
	if r.config.PruneInterval > 0 {
		pruneTicker := time.NewTicker(r.config.PruneInterval)
		defer pruneTicker.Stop()
		pruneTickerChan = pruneTicker.C
	}

	// Create the change iterator to watch for inserts and deletes to the table.
	wtxn := r.DB.WriteTxn(r.config.Table)
	changeIterator, err := r.config.Table.Changes(wtxn)
	txn := wtxn.Commit()
	if err != nil {
		return fmt.Errorf("watching for changes failed: %w", err)
	}

	// Start with a closed channel so that the first iteration proceeds
	// immediately and processes the existing contents of the table.
	tableWatchChan := closedWatchChannel

	// externalPrune stays set across iterations until a prune has actually
	// been performed, i.e. until the table has been initialized.
	externalPrune := false

	tableInitialized := false
	_, tableInitWatch := r.config.Table.Initialized(txn)

	for {
		// Throttle a bit before reconciliation to allow for a bigger batch to arrive and
		// for objects to settle.
		if err := r.config.RateLimiter.Wait(ctx); err != nil {
			return err
		}

		prune := false

		// Wait for trigger
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-r.retries.Wait():
			// Object(s) are ready to be retried
		case <-tableWatchChan:
			// Table has changed
		case <-tableInitWatch:
			tableInitialized = true
			// A nil channel blocks forever, so this case fires only once.
			tableInitWatch = nil

			// Do an immediate pruning now as the table has finished
			// initializing and pruning is enabled.
			prune = r.config.PruneInterval != 0
		case <-pruneTickerChan:
			prune = true
		case <-r.externalPruneTrigger:
			externalPrune = true
		}

		// Grab a new snapshot and refresh the changes iterator to read
		// in the new changes.
		txn = r.DB.ReadTxn()
		var changes iter.Seq2[statedb.Change[Obj], statedb.Revision]
		changes, tableWatchChan = changeIterator.Next(txn)

		// Perform incremental reconciliation and retries of previously failed
		// objects.
		errs := r.incremental(ctx, txn, changes)

		if tableInitialized && (prune || externalPrune) {
			if err := r.prune(ctx, txn); err != nil {
				errs = append(errs, err)
			}
			externalPrune = false
		}

		if len(errs) == 0 {
			health.OK(
				fmt.Sprintf("OK, %d object(s)", r.config.Table.NumObjects(txn)))
		} else {
			health.Degraded(
				fmt.Sprintf("%d error(s)", len(errs)),
				joinErrors(errs))
		}
	}
}

// prune performs the Prune operation to delete unexpected objects in the target system.
// The duration and the outcome are reported to the metrics; a failure is
// logged and returned wrapped with a "prune: " prefix.
func (r *reconciler[Obj]) prune(ctx context.Context, txn statedb.ReadTxn) error {
	iter := r.config.Table.All(txn)
	start := time.Now()
	err := r.config.Operations.Prune(ctx, txn, iter)
	if err != nil {
		r.Log.Warn("Reconciler: failed to prune objects", "error", err, "pruneInterval", r.config.PruneInterval)
		err = fmt.Errorf("prune: %w", err)
	}
	r.config.Metrics.PruneDuration(r.ModuleID, time.Since(start))
	r.config.Metrics.PruneError(r.ModuleID, err)
	return err
}

// refreshLoop periodically marks objects whose status is StatusKindDone and
// older than RefreshInterval with StatusRefreshing so that they are
// reconciled again. It returns nil when the context is cancelled.
func (r *reconciler[Obj]) refreshLoop(ctx context.Context, health cell.Health) error {
	// lastRevision is the revision of the last object that was considered
	// for refresh; the next scan starts right after it.
	lastRevision := statedb.Revision(0)

	// A zero duration makes the first refresh scan run immediately.
	refreshTimer := time.NewTimer(0)
	defer refreshTimer.Stop()

	for {
		// Wait until it's time to refresh.
		select {
		case <-ctx.Done():
			return nil

		case <-refreshTimer.C:
		}

		durationUntilRefresh := r.config.RefreshInterval

		// Iterate over the objects in revision order, e.g. oldest modification first.
		// We look for objects that are older than [RefreshInterval] and mark them for
		// refresh in order for them to be reconciled again.
		seq := r.config.Table.LowerBound(r.DB.ReadTxn(), statedb.ByRevision[Obj](lastRevision+1))
		indexer := r.config.Table.PrimaryIndexer()

		for obj, rev := range seq {
			status := r.config.GetObjectStatus(obj)

			// The duration elapsed since this object was last updated.
			updatedSince := time.Since(status.UpdatedAt)

			// Have we reached an object that is newer than RefreshInterval?
			// If so, wait until this now oldest object's UpdatedAt exceeds RefreshInterval.
			if updatedSince < r.config.RefreshInterval {
				durationUntilRefresh = r.config.RefreshInterval - updatedSince
				break
			}

			lastRevision = rev

			if status.Kind == StatusKindDone {
				if r.config.RefreshRateLimiter != nil {
					// Limit the rate at which objects are marked for refresh to avoid disrupting
					// normal work.
					if err := r.config.RefreshRateLimiter.Wait(ctx); err != nil {
						// Stop this scan; a cancelled context is picked up
						// by the select at the top of the outer loop.
						break
					}
				}

				// Mark the object for refreshing. We make the assumption that refreshing is spread over
				// time enough that batching of the writes is not useful here.
				wtxn := r.DB.WriteTxn(r.config.Table)
				obj, newRev, ok := r.config.Table.Get(wtxn, indexer.QueryFromObject(obj))
				// Only mark the object if it still exists and has not been
				// modified since the snapshot was taken.
				if ok && rev == newRev {
					obj = r.config.SetObjectStatus(r.config.CloneObject(obj), StatusRefreshing())
					r.config.Table.Insert(wtxn, obj)
				}
				wtxn.Commit()
			}
		}

		refreshTimer.Reset(durationUntilRefresh)
		health.OK(fmt.Sprintf("Next refresh in %s", durationUntilRefresh))
	}
}

--------------------------------------------------------------------------------
/reconciler/retries.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package reconciler

import (
	"container/heap"
	"math"
	"time"

	"github.com/cilium/statedb"
	"github.com/cilium/statedb/index"
)

// exponentialBackoff computes retry delays that double with each attempt,
// starting from min and capped at max.
type exponentialBackoff struct {
	min time.Duration
	max time.Duration
}

// Duration returns min * 2^attempt, capped at max. The computation is done
// in floating point so that large attempt counts saturate at max instead of
// overflowing.
func (e *exponentialBackoff) Duration(attempt int) time.Duration {
	dur := float64(e.min) * math.Pow(2, float64(attempt))
	if dur > float64(e.max) {
		return e.max
	}
	return time.Duration(dur)
}

// newRetries creates an empty retry queue with the given backoff bounds.
// objectToKey returns the key under which an object's retry state is kept.
func newRetries(minDuration, maxDuration time.Duration, objectToKey func(any) index.Key) *retries {
	return &retries{
		backoff: exponentialBackoff{
			min: minDuration,
			max: maxDuration,
		},
		queue:       nil,
		items:       make(map[string]*retryItem),
		objectToKey: objectToKey,
		waitTimer:   nil,
		waitChan:    make(chan struct{}),
	}
}

// retries holds the items that failed to be reconciled in
// a priority queue ordered by retry time. Methods of 'retries'
// are not safe to access concurrently.
type retries struct {
	backoff     exponentialBackoff
	queue       retryPrioQueue        // items waiting to be retried, earliest retryAt first
	items       map[string]*retryItem // retry state by object key; kept after Pop until Clear
	objectToKey func(any) index.Key
	waitTimer   *time.Timer   // closes waitChan when the top item is due; nil when not armed
	waitChan    chan struct{} // the channel handed out by Wait
}

// errors returns the last error of every object that has retry state,
// including objects that were popped from the queue but not yet cleared.
// The order is unspecified as it follows map iteration.
func (rq *retries) errors() []error {
	errs := make([]error, 0, len(rq.items))
	for _, item := range rq.items {
		errs = append(errs, item.lastError)
	}
	return errs
}

type retryItem struct {
	object any // the object that is being retried. 'any' to avoid specializing this internal code.
	rev    statedb.Revision
	delete bool

	index      int       // item's index in the priority queue
	retryAt    time.Time // time at which to retry
	numRetries int       // number of retries attempted (for calculating backoff)
	lastError  error
}

// Wait returns a channel that is closed when there is an item to retry.
// The channel is never nil: while no retry is due it is simply left open,
// so receiving from it blocks. The channel is replaced when the timer is
// re-armed, so it must be fetched again with Wait after each change to the
// queue.
func (rq *retries) Wait() <-chan struct{} {
	return rq.waitChan
}

// Top returns the item with the earliest retry time without removing it.
// The boolean is false if the queue is empty.
func (rq *retries) Top() (*retryItem, bool) {
	if rq.queue.Len() == 0 {
		return nil, false
	}
	return rq.queue[0], true
}

// Pop removes the top item from the queue and re-arms the timer for the
// next one. The queue must not be empty.
func (rq *retries) Pop() {
	// Pop the object from the queue, but leave it into the map until
	// the object is cleared or re-added.
	rq.queue[0].index = -1
	heap.Pop(&rq.queue)

	rq.resetTimer()
}

// resetTimer re-arms the wait timer for the current top of the queue.
func (rq *retries) resetTimer() {
	if rq.waitTimer == nil || !rq.waitTimer.Stop() {
		// There is no timer, or it already fired (closing the channel) or
		// was stopped earlier. Create a new channel and timer.
		waitChan := make(chan struct{})
		rq.waitChan = waitChan
		if rq.queue.Len() == 0 {
			rq.waitTimer = nil
		} else {
			d := time.Until(rq.queue[0].retryAt)
			// The closure captures the local channel so that a late firing
			// can only ever close the channel it was created for.
			rq.waitTimer = time.AfterFunc(d, func() { close(waitChan) })
		}
	} else if rq.queue.Len() > 0 {
		d := time.Until(rq.queue[0].retryAt)
		// Did not fire yet so we can just reset the timer.
		rq.waitTimer.Reset(d)
	}
}

// Add queues the object for a retry, or reschedules it if it already has
// retry state. The retry count is incremented on every call and determines
// the backoff; it is only forgotten by Clear.
func (rq *retries) Add(obj any, rev statedb.Revision, delete bool, lastError error) {
	var (
		item *retryItem
		ok   bool
	)
	key := rq.objectToKey(obj)
	if item, ok = rq.items[string(key)]; !ok {
		item = &retryItem{
			numRetries: 0,
			index:      -1,
		}
		rq.items[string(key)] = item
	}
	item.object = obj
	item.rev = rev
	item.delete = delete
	item.numRetries += 1
	item.lastError = lastError
	duration := rq.backoff.Duration(item.numRetries)
	item.retryAt = time.Now().Add(duration)

	if item.index >= 0 {
		// The item was already in the queue, fix up its position.
		heap.Fix(&rq.queue, item.index)
	} else {
		heap.Push(&rq.queue, item)
	}

	// Item is at the head of the queue, reset the timer.
	if item.index == 0 {
		rq.resetTimer()
	}
}

// Clear forgets the object's retry state, including its retry count, and
// removes it from the queue if it is still queued.
func (rq *retries) Clear(obj any) {
	key := rq.objectToKey(obj)
	if item, ok := rq.items[string(key)]; ok {
		// Remove the object from the queue if it is still there.
		index := item.index // hold onto the index as heap.Remove messes with it
		if item.index >= 0 && item.index < len(rq.queue) &&
			key.Equal(rq.objectToKey(rq.queue[item.index].object)) {
			heap.Remove(&rq.queue, item.index)

			// Reset the timer in case we removed the top item.
			if index == 0 {
				rq.resetTimer()
			}
		}
		// Completely forget the object and its retry count.
		delete(rq.items, string(key))
	}
}

// retryPrioQueue is a slice-backed priority heap with the next
// expiring 'retryItem' at top. Implementation is adapted from the
// 'container/heap' PriorityQueue example.
170 | type retryPrioQueue []*retryItem 171 | 172 | func (pq retryPrioQueue) Len() int { return len(pq) } 173 | 174 | func (pq retryPrioQueue) Less(i, j int) bool { 175 | return pq[i].retryAt.Before(pq[j].retryAt) 176 | } 177 | 178 | func (pq retryPrioQueue) Swap(i, j int) { 179 | pq[i], pq[j] = pq[j], pq[i] 180 | pq[i].index = i 181 | pq[j].index = j 182 | } 183 | 184 | func (pq *retryPrioQueue) Push(x any) { 185 | retryObj := x.(*retryItem) 186 | retryObj.index = len(*pq) 187 | *pq = append(*pq, retryObj) 188 | } 189 | 190 | func (pq *retryPrioQueue) Pop() any { 191 | old := *pq 192 | n := len(old) 193 | item := old[n-1] 194 | old[n-1] = nil // avoid memory leak 195 | item.index = -1 // for safety 196 | *pq = old[0 : n-1] 197 | return item 198 | } 199 | -------------------------------------------------------------------------------- /reconciler/retries_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package reconciler 5 | 6 | import ( 7 | "errors" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/cilium/statedb/index" 15 | ) 16 | 17 | func TestRetries(t *testing.T) { 18 | objectToKey := func(o any) index.Key { 19 | return index.Uint64(o.(uint64)) 20 | } 21 | rq := newRetries(time.Millisecond, 100*time.Millisecond, objectToKey) 22 | 23 | obj1, obj2, obj3 := uint64(1), uint64(2), uint64(3) 24 | 25 | // Add objects to be retried in order. We assume here that 'time.Time' has 26 | // enough granularity for these to be added with rising retryAt times. 
27 | err := errors.New("some error") 28 | rq.Add(obj1, 1, false, err) 29 | rq.Add(obj2, 2, false, err) 30 | rq.Add(obj3, 3, false, err) 31 | 32 | errs := rq.errors() 33 | assert.Len(t, errs, 3) 34 | assert.Equal(t, err, errs[0]) 35 | 36 | // Adding an item a second time will increment the number of retries and 37 | // recalculate when it should be retried. 38 | rq.Add(obj3, 3, false, err) 39 | 40 | <-rq.Wait() 41 | item1, ok := rq.Top() 42 | if assert.True(t, ok) { 43 | assert.True(t, item1.retryAt.Before(time.Now()), "expected item to be expired") 44 | assert.Equal(t, item1.object, obj1) 45 | 46 | rq.Pop() 47 | rq.Clear(item1.object) 48 | } 49 | 50 | <-rq.Wait() 51 | item2, ok := rq.Top() 52 | if assert.True(t, ok) { 53 | assert.True(t, item2.retryAt.Before(time.Now()), "expected item to be expired") 54 | assert.True(t, item2.retryAt.After(item1.retryAt), "expected item to expire later than previous") 55 | assert.Equal(t, item2.object, obj2) 56 | 57 | rq.Pop() 58 | rq.Clear(item2.object) 59 | } 60 | 61 | // Pop the last object. But don't clear its retry count. 62 | <-rq.Wait() 63 | item3, ok := rq.Top() 64 | if assert.True(t, ok) { 65 | assert.True(t, item3.retryAt.Before(time.Now()), "expected item to be expired") 66 | assert.True(t, item3.retryAt.After(item2.retryAt), "expected item to expire later than previous") 67 | assert.Equal(t, item3.object, obj3) 68 | 69 | rq.Pop() 70 | } 71 | 72 | // Queue should be empty now. 73 | item, ok := rq.Top() 74 | assert.False(t, ok) 75 | 76 | // Retry 'obj3' and since it was added back without clearing it'll be retried 77 | // later. Add obj1 and check that 'obj3' has later retry time. 
78 | rq.Add(obj3, 4, false, err) 79 | rq.Add(obj1, 5, false, err) 80 | 81 | <-rq.Wait() 82 | item4, ok := rq.Top() 83 | if assert.True(t, ok) { 84 | assert.True(t, item4.retryAt.Before(time.Now()), "expected item to be expired") 85 | assert.Equal(t, item4.object, obj1) 86 | 87 | rq.Pop() 88 | rq.Clear(item4.object) 89 | } 90 | 91 | <-rq.Wait() 92 | item5, ok := rq.Top() 93 | if assert.True(t, ok) { 94 | assert.True(t, item5.retryAt.Before(time.Now()), "expected item to be expired") 95 | assert.True(t, item5.retryAt.After(item4.retryAt), "expected obj1 before obj3") 96 | assert.Equal(t, obj3, item5.object) 97 | 98 | // numRetries is 3 since 'obj3' was added to the queue 3 times and it has not 99 | // been cleared. 100 | assert.Equal(t, 3, item5.numRetries) 101 | 102 | rq.Pop() 103 | rq.Clear(item5.object) 104 | } 105 | 106 | _, ok = rq.Top() 107 | assert.False(t, ok) 108 | 109 | // Test that object can be cleared from the queue without popping it. 110 | rq.Add(obj1, 6, false, err) 111 | rq.Add(obj2, 7, false, err) 112 | rq.Add(obj3, 8, false, err) 113 | 114 | rq.Clear(obj1) // Remove obj1, testing that it'll fix the queue correctly. 
115 | rq.Pop() // Pop and remove obj2 and clear it to test that Clear doesn't mess with queue 116 | rq.Clear(obj2) 117 | item, ok = rq.Top() 118 | if assert.True(t, ok) { 119 | rq.Pop() 120 | rq.Clear(item.object) 121 | assert.Equal(t, item.object, obj3) 122 | } 123 | _, ok = rq.Top() 124 | assert.False(t, ok) 125 | } 126 | 127 | func TestExponentialBackoff(t *testing.T) { 128 | backoff := exponentialBackoff{ 129 | min: time.Millisecond, 130 | max: time.Second, 131 | } 132 | 133 | for i := 0; i < 1000; i++ { 134 | dur := backoff.Duration(i) 135 | require.GreaterOrEqual(t, dur, backoff.min) 136 | require.LessOrEqual(t, dur, backoff.max) 137 | } 138 | require.Equal(t, backoff.Duration(0)*2, backoff.Duration(1)) 139 | } 140 | -------------------------------------------------------------------------------- /reconciler/status_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package reconciler 5 | 6 | import ( 7 | "encoding/json" 8 | "errors" 9 | "regexp" 10 | "testing" 11 | "time" 12 | 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestStatusString(t *testing.T) { 17 | now := time.Now() 18 | 19 | s := Status{ 20 | Kind: StatusKindPending, 21 | UpdatedAt: now, 22 | Error: "", 23 | } 24 | assert.Regexp(t, `Pending \([0-9]+\.[0-9]+.+s ago\)`, s.String()) 25 | s.UpdatedAt = now.Add(-time.Hour) 26 | assert.Regexp(t, `Pending \([0-9]+\.[0-9]+h ago\)`, s.String()) 27 | 28 | s = Status{ 29 | Kind: StatusKindDone, 30 | UpdatedAt: now, 31 | Error: "", 32 | } 33 | assert.Regexp(t, `Done \([0-9]+\.[0-9]+.+s ago\)`, s.String()) 34 | 35 | s = Status{ 36 | Kind: StatusKindError, 37 | UpdatedAt: now, 38 | Error: "hey I'm an error", 39 | } 40 | assert.Regexp(t, `Error: hey I'm an error \([0-9]+\.[0-9]+.+s ago\)`, s.String()) 41 | } 42 | 43 | func TestStatusJSON(t *testing.T) { 44 | testCases := []struct { 45 | s Status 46 | expected string 47 | 
}{ 48 | { 49 | Status{ 50 | Kind: StatusKindDone, 51 | UpdatedAt: time.Unix(1, 0).UTC(), 52 | Error: "", 53 | }, 54 | `{"kind":"Done","updated-at":"1970-01-01T00:00:01Z"}`, 55 | }, 56 | { 57 | Status{ 58 | Kind: StatusKindPending, 59 | UpdatedAt: time.Unix(2, 0).UTC(), 60 | Error: "", 61 | }, 62 | `{"kind":"Pending","updated-at":"1970-01-01T00:00:02Z"}`, 63 | }, 64 | { 65 | Status{ 66 | Kind: StatusKindError, 67 | UpdatedAt: time.Unix(3, 0).UTC(), 68 | Error: "some-error", 69 | }, 70 | `{"kind":"Error","updated-at":"1970-01-01T00:00:03Z","error":"some-error"}`, 71 | }, 72 | { 73 | Status{ 74 | Kind: StatusKindRefreshing, 75 | UpdatedAt: time.Unix(4, 0).UTC(), 76 | Error: "", 77 | }, 78 | `{"kind":"Refreshing","updated-at":"1970-01-01T00:00:04Z"}`, 79 | }, 80 | } 81 | 82 | for _, tc := range testCases { 83 | b, err := json.Marshal(tc.s) 84 | assert.NoError(t, err, "Marshal") 85 | assert.Equal(t, tc.expected, string(b)) 86 | 87 | var s Status 88 | assert.NoError(t, json.Unmarshal(b, &s), "Unmarshal") 89 | assert.Equal(t, tc.s, s) 90 | } 91 | 92 | } 93 | 94 | func sanitizeAgo(s string) string { 95 | r := regexp.MustCompile(`\(.* ago\)`) 96 | return string(r.ReplaceAll([]byte(s), []byte("(??? 
ago)"))) 97 | } 98 | 99 | func TestStatusSet(t *testing.T) { 100 | assertJSONRoundtrip := func(s StatusSet) { 101 | data, err := json.Marshal(s) 102 | assert.NoError(t, err, "Marshal") 103 | var s2 StatusSet 104 | err = json.Unmarshal(data, &s2) 105 | assert.NoError(t, err, "Unmarshal") 106 | assert.Equal(t, sanitizeAgo(s.String()), sanitizeAgo(s2.String())) 107 | } 108 | 109 | set := NewStatusSet() 110 | assert.Equal(t, "Pending", set.String()) 111 | assertJSONRoundtrip(set) 112 | 113 | s := set.Get("foo") 114 | assert.Equal(t, s.Kind, StatusKindPending) 115 | assert.NotZero(t, s.ID) 116 | 117 | set = set.Set("foo", StatusDone()) 118 | set = set.Set("bar", StatusError(errors.New("fail"))) 119 | assertJSONRoundtrip(set) 120 | 121 | assert.Equal(t, set.Get("foo").Kind, StatusKindDone) 122 | assert.Equal(t, set.Get("bar").Kind, StatusKindError) 123 | assert.Regexp(t, "^Errored: bar \\(fail\\), Done: foo \\(.* ago\\)", set.String()) 124 | 125 | set = set.Pending() 126 | assert.NotZero(t, set.Get("foo").ID) 127 | assert.Equal(t, set.Get("foo").Kind, StatusKindPending) 128 | assert.Equal(t, set.Get("bar").Kind, StatusKindPending) 129 | assert.Equal(t, set.Get("baz").Kind, StatusKindPending) 130 | assert.Regexp(t, "^Pending: bar foo \\(.* ago\\)", set.String()) 131 | assertJSONRoundtrip(set) 132 | } 133 | -------------------------------------------------------------------------------- /reconciler/testdata/batching.txtar: -------------------------------------------------------------------------------- 1 | # Test the incremental reconciliation with 2 | # batching. 3 | 4 | hive start 5 | start-reconciler with-batchops 6 | 7 | # From here this is the same as incremental.txtar. 
8 | 9 | # Step 1: Insert non-faulty objects 10 | db/insert test-objects obj1.yaml 11 | db/insert test-objects obj2.yaml 12 | db/insert test-objects obj3.yaml 13 | db/cmp test-objects step1+3.table 14 | expect-ops update(1) update(2) update(3) 15 | 16 | # Reconciler should be running and reporting health 17 | health 'job-reconcile.*level=OK.*message=OK, 3 object' 18 | 19 | # Step 2: Update object '1' to be faulty and check that it fails and is being 20 | # retried. 21 | db/insert test-objects obj1_faulty.yaml 22 | expect-ops 'update(1) fail' 'update(1) fail' 23 | db/cmp test-objects step2.table 24 | health 'job-reconcile.*level=Degraded.*1 error' 25 | 26 | # Step 3: Set object '1' back to healthy state 27 | db/insert test-objects obj1.yaml 28 | expect-ops 'update(1)' 29 | db/cmp test-objects step1+3.table 30 | health 'job-reconcile.*level=OK' 31 | 32 | # Step 4: Delete '1' and '2' 33 | db/delete test-objects obj1.yaml 34 | db/delete test-objects obj2.yaml 35 | db/cmp test-objects step4.table 36 | expect-ops 'delete(1)' 'delete(2)' 37 | 38 | # Step 5: Try to delete '3' with faulty target 39 | set-faulty true 40 | db/delete test-objects obj3.yaml 41 | db/cmp test-objects empty.table 42 | expect-ops 'delete(3) fail' 43 | health 'job-reconcile.*level=Degraded.*1 error' 44 | 45 | # Step 6: Set the target back to healthy 46 | set-faulty false 47 | expect-ops 'delete(3)' 48 | health 'job-reconcile.*level=OK.*message=OK, 0 object' 49 | 50 | # Check metrics 51 | expvar 52 | ! stdout 'reconciliation_count.test: 0$' 53 | stdout 'reconciliation_current_errors.test: 0$' 54 | ! stdout 'reconciliation_total_errors.test: 0$' 55 | ! stdout 'reconciliation_duration.test/update: 0$' 56 | ! 
stdout 'reconciliation_duration.test/delete: 0$' 57 | 58 | # ------------ 59 | 60 | -- empty.table -- 61 | ID StatusKind 62 | 63 | -- step1+3.table -- 64 | ID StatusKind StatusError 65 | 1 Done 66 | 2 Done 67 | 3 Done 68 | 69 | -- step2.table -- 70 | ID StatusKind StatusError 71 | 1 Error update fail 72 | 2 Done 73 | 3 Done 74 | 75 | -- step4.table -- 76 | ID StatusKind 77 | 3 Done 78 | 79 | -- step7.table -- 80 | ID Faulty StatusKind StatusError 81 | 4 true Error update fail 82 | 5 false Done 83 | 84 | 85 | -- step8.table -- 86 | ID Faulty StatusKind 87 | 4 false Done 88 | 5 false Done 89 | 90 | 91 | -- obj1.yaml -- 92 | id: 1 93 | faulty: false 94 | status: 95 | kind: Pending 96 | id: 1 97 | 98 | -- obj1_faulty.yaml -- 99 | id: 1 100 | faulty: true 101 | status: 102 | kind: Pending 103 | id: 2 104 | 105 | -- obj2.yaml -- 106 | id: 2 107 | faulty: false 108 | status: 109 | kind: Pending 110 | id: 3 111 | 112 | -- obj2_faulty.yaml -- 113 | id: 2 114 | faulty: true 115 | status: 116 | kind: Pending 117 | id: 4 118 | 119 | -- obj3.yaml -- 120 | id: 3 121 | faulty: false 122 | status: 123 | kind: Pending 124 | id: 5 125 | 126 | -- obj3_faulty.yaml -- 127 | id: 3 128 | faulty: true 129 | status: 130 | kind: Pending 131 | id: 6 132 | 133 | -------------------------------------------------------------------------------- /reconciler/testdata/incremental.txtar: -------------------------------------------------------------------------------- 1 | # Test the incremental reconciliation with non-batched operations 2 | # and without pruning. 
3 | 4 | hive start 5 | start-reconciler 6 | 7 | # Step 1: Insert non-faulty objects 8 | db/insert test-objects obj1.yaml 9 | db/insert test-objects obj2.yaml 10 | db/insert test-objects obj3.yaml 11 | db/cmp test-objects step1+3.table 12 | expect-ops update(1) update(2) update(3) 13 | 14 | # Reconciler should be running and reporting health 15 | health 'job-reconcile.*level=OK.*message=OK, 3 object' 16 | 17 | # Step 2: Update object '1' to be faulty and check that it fails and is being 18 | # retried. 19 | db/insert test-objects obj1_faulty.yaml 20 | db/cmp test-objects step2.table 21 | expect-ops 'update(1) fail' 'update(1) fail' 22 | health 'job-reconcile.*level=Degraded.*1 error' 23 | 24 | # Step 3: Set object '1' back to healthy state 25 | db/insert test-objects obj1.yaml 26 | db/show test-objects 27 | db/cmp test-objects step1+3.table 28 | expect-ops 'update(1)' 29 | health 'job-reconcile.*level=OK' 30 | 31 | # Step 4: Delete '1' and '2' 32 | db/delete test-objects obj1.yaml 33 | db/delete test-objects obj2.yaml 34 | db/cmp test-objects step4.table 35 | expect-ops 'delete(1)' 'delete(2)' 36 | 37 | # Step 5: Try to delete '3' with faulty target 38 | set-faulty true 39 | db/delete test-objects obj3.yaml 40 | db/cmp test-objects empty.table 41 | expect-ops 'delete(3) fail' 42 | health 'job-reconcile.*level=Degraded.*1 error' 43 | 44 | # Step 6: Set the target back to healthy 45 | set-faulty false 46 | expect-ops 'delete(3)' 47 | health 'job-reconcile.*level=OK.*message=OK, 0 object' 48 | 49 | # Check metrics 50 | expvar 51 | ! stdout 'reconciliation_count.test: 0$' 52 | stdout 'reconciliation_current_errors.test: 0$' 53 | ! stdout 'reconciliation_total_errors.test: 0$' 54 | ! stdout 'reconciliation_duration.test/update: 0$' 55 | ! 
stdout 'reconciliation_duration.test/delete: 0$' 56 | 57 | # ------------ 58 | 59 | -- empty.table -- 60 | ID StatusKind 61 | 62 | -- step1+3.table -- 63 | ID StatusKind StatusError 64 | 1 Done 65 | 2 Done 66 | 3 Done 67 | 68 | -- step2.table -- 69 | ID StatusKind StatusError 70 | 1 Error update fail 71 | 2 Done 72 | 3 Done 73 | 74 | -- step4.table -- 75 | ID StatusKind 76 | 3 Done 77 | 78 | -- step7.table -- 79 | ID Faulty StatusKind StatusError 80 | 4 true Error update fail 81 | 5 false Done 82 | 83 | -- step8.table -- 84 | ID Faulty StatusKind 85 | 4 false Done 86 | 5 false Done 87 | 88 | -- obj1.yaml -- 89 | id: 1 90 | faulty: false 91 | status: 92 | kind: Pending 93 | id: 1 94 | 95 | -- obj1_faulty.yaml -- 96 | id: 1 97 | faulty: true 98 | status: 99 | kind: Pending 100 | id: 2 101 | 102 | -- obj2.yaml -- 103 | id: 2 104 | faulty: false 105 | status: 106 | kind: Pending 107 | id: 3 108 | 109 | -- obj2_faulty.yaml -- 110 | id: 2 111 | faulty: true 112 | status: 113 | kind: Pending 114 | id: 4 115 | 116 | -- obj3.yaml -- 117 | id: 3 118 | faulty: false 119 | status: 120 | kind: Pending 121 | id: 5 122 | 123 | -- obj3_faulty.yaml -- 124 | id: 3 125 | faulty: true 126 | status: 127 | kind: Pending 128 | id: 6 129 | 130 | -------------------------------------------------------------------------------- /reconciler/testdata/prune_empty.txtar: -------------------------------------------------------------------------------- 1 | hive start 2 | start-reconciler with-prune 3 | 4 | # Pruning happens when table is initialized even without any objects. 5 | mark-init 6 | expect-ops prune(n=0) 7 | health 'job-reconcile.*level=OK' 8 | -------------------------------------------------------------------------------- /reconciler/testdata/pruning.txtar: -------------------------------------------------------------------------------- 1 | hive start 2 | start-reconciler with-prune 3 | 4 | # Pruning without table being initialized does nothing. 
5 | db/insert test-objects obj1.yaml 6 | expect-ops update(1) 7 | prune 8 | db/insert test-objects obj2.yaml 9 | expect-ops update(2) update(1) 10 | health 'job-reconcile.*level=OK' 11 | 12 | # After init pruning happens immediately 13 | mark-init 14 | expect-ops prune(n=2) 15 | health 'job-reconcile.*level=OK' 16 | expvar 17 | ! stdout 'prune_count.test: 0' 18 | 19 | # Pruning with faulty ops will mark status as degraded 20 | set-faulty true 21 | prune 22 | expect-ops 'prune(n=2) fail' 23 | health 'job-reconcile.*level=Degraded.*message=.*prune fail' 24 | expvar 25 | stdout 'prune_current_errors.test: 1' 26 | 27 | # Pruning again with healthy ops fixes the status. 28 | set-faulty false 29 | prune 30 | expect-ops 'prune(n=2)' 31 | health 'job-reconcile.*level=OK' 32 | expvar 33 | stdout 'prune_current_errors.test: 0' 34 | 35 | # Delete an object and check pruning happens without it 36 | db/delete test-objects obj1.yaml 37 | prune 38 | expect-ops 'prune(n=1)' delete(1) 39 | 40 | # Prune without objects 41 | db/delete test-objects obj2.yaml 42 | prune 43 | expect-ops prune(n=0) delete(2) prune(n=1) 44 | 45 | # Check metrics 46 | expvar 47 | ! stdout 'prune_count.test: 0' 48 | stdout 'prune_current_errors.test: 0' 49 | stdout 'prune_total_errors.test: 1' 50 | ! stdout 'prune_duration.test: 0$' 51 | ! stdout 'reconciliation_count.test: 0$' 52 | stdout 'reconciliation_current_errors.test: 0$' 53 | stdout 'reconciliation_total_errors.test: 0$' 54 | ! stdout 'reconciliation_duration.test/update: 0$' 55 | ! 
stdout 'reconciliation_duration.test/delete: 0$' 56 | 57 | -- obj1.yaml -- 58 | id: 1 59 | faulty: false 60 | status: 61 | kind: Pending 62 | id: 1 63 | 64 | -- obj2.yaml -- 65 | id: 2 66 | faulty: false 67 | status: 68 | kind: Pending 69 | id: 2 70 | 71 | -------------------------------------------------------------------------------- /reconciler/testdata/refresh.txtar: -------------------------------------------------------------------------------- 1 | hive start 2 | start-reconciler with-refresh 3 | 4 | # Step 1: Add a test object. 5 | db/insert test-objects obj1.yaml 6 | expect-ops 'update(1)' 7 | db/cmp test-objects step1.table 8 | 9 | # Step 2: Set the object as updated in the past to force refresh 10 | db/insert test-objects obj1_old.yaml 11 | expect-ops 'update-refresh(1)' 12 | 13 | # Step 3: Refresh with faulty target, should see fail & retries 14 | set-faulty true 15 | db/insert test-objects obj1_old.yaml 16 | expect-ops 'update-refresh(1) fail' 'update-refresh(1) fail' 17 | db/cmp test-objects step3.table 18 | health 19 | health 'job-reconcile.*Degraded' 20 | 21 | # Step 4: Back to health 22 | set-faulty false 23 | db/insert test-objects obj1_old.yaml 24 | expect-ops 'update-refresh(1)' 25 | db/cmp test-objects step4.table 26 | health 'job-reconcile.*OK, 1 object' 27 | 28 | # ----- 29 | -- step1.table -- 30 | ID StatusKind 31 | 1 Done 32 | 33 | -- step3.table -- 34 | ID StatusKind 35 | 1 Error 36 | 37 | -- step4.table -- 38 | ID StatusKind 39 | 1 Done 40 | 41 | -- obj1.yaml -- 42 | id: 1 43 | faulty: false 44 | updates: 1 45 | status: 46 | kind: Pending 47 | updatedat: 2024-01-01T10:10:10.0+02:00 48 | error: "" 49 | id: 2 50 | 51 | -- obj1_old.yaml -- 52 | id: 1 53 | faulty: false 54 | updates: 1 55 | status: 56 | kind: Done 57 | updatedat: 2000-01-01T10:10:10.0+02:00 58 | error: "" 59 | id: 1 60 | 61 | -------------------------------------------------------------------------------- /regression_test.go: 
-------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | 13 | "github.com/cilium/statedb/index" 14 | ) 15 | 16 | // Test_Regression_29324 tests that Get() on a index.String-based 17 | // unique index only returns exact matches. 18 | // https://github.com/cilium/cilium/issues/29324 19 | func Test_Regression_29324(t *testing.T) { 20 | type object struct { 21 | ID string 22 | Tag string 23 | } 24 | idIndex := Index[object, string]{ 25 | Name: "id", 26 | FromObject: func(t object) index.KeySet { 27 | return index.NewKeySet(index.String(t.ID)) 28 | }, 29 | FromKey: index.String, 30 | Unique: true, 31 | } 32 | tagIndex := Index[object, string]{ 33 | Name: "tag", 34 | FromObject: func(t object) index.KeySet { 35 | return index.NewKeySet(index.String(t.Tag)) 36 | }, 37 | FromKey: index.String, 38 | Unique: false, 39 | } 40 | 41 | db, _, _ := newTestDB(t) 42 | table, err := NewTable("objects", idIndex, tagIndex) 43 | require.NoError(t, err) 44 | require.NoError(t, db.RegisterTable(table)) 45 | 46 | wtxn := db.WriteTxn(table) 47 | table.Insert(wtxn, object{"foo", "aa"}) 48 | table.Insert(wtxn, object{"foobar", "aaa"}) 49 | table.Insert(wtxn, object{"baz", "aaaa"}) 50 | wtxn.Commit() 51 | 52 | // Exact match should only return "foo" 53 | txn := db.ReadTxn() 54 | iter := table.List(txn, idIndex.Query("foo")) 55 | items := Collect(iter) 56 | if assert.Len(t, items, 1, "Get(\"foo\") should return one match") { 57 | assert.EqualValues(t, "foo", items[0].ID) 58 | } 59 | 60 | // Partial match on prefix should not return anything 61 | iter = table.List(txn, idIndex.Query("foob")) 62 | items = Collect(iter) 63 | assert.Len(t, items, 0, "Get(\"foob\") should return nothing") 64 | 65 | // Query on non-unique index should only 
return exact match 66 | iter = table.List(txn, tagIndex.Query("aa")) 67 | items = Collect(iter) 68 | if assert.Len(t, items, 1, "Get(\"aa\") on tags should return one match") { 69 | assert.EqualValues(t, "foo", items[0].ID) 70 | } 71 | 72 | // Partial match on prefix should not return anything on non-unique index 73 | iter = table.List(txn, idIndex.Query("a")) 74 | items = Collect(iter) 75 | assert.Len(t, items, 0, "Get(\"a\") should return nothing") 76 | } 77 | 78 | // The watch channel returned by Changes() must be a closed one if there 79 | // is anything left to iterate over. Otherwise on partial iteration we'll 80 | // wait on a watch channel that reflects the changes of a full iteration 81 | // and we might be stuck waiting even when there's unprocessed changes. 82 | func Test_Regression_Changes_Watch(t *testing.T) { 83 | db, table, _ := newTestDB(t) 84 | 85 | wtxn := db.WriteTxn(table) 86 | changeIter, err := table.Changes(wtxn) 87 | require.NoError(t, err, "Changes") 88 | wtxn.Commit() 89 | 90 | n := 0 91 | changes, watch := changeIter.Next(db.ReadTxn()) 92 | for change := range changes { 93 | t.Fatalf("did not expect changes, got: %v", change) 94 | } 95 | 96 | // The returned watch channel is closed on the first call to Next() 97 | // as there may have been changes to iterate and we want it to be 98 | // safe to either partially consume the changes or even block first 99 | // on the watch channel and only then consume. 100 | select { 101 | case <-watch: 102 | default: 103 | t.Fatalf("Changes() watch channel not closed") 104 | } 105 | 106 | // Calling Next() again now will get a proper non-closed watch channel. 
107 | changes, watch = changeIter.Next(db.ReadTxn()) 108 | for change := range changes { 109 | t.Fatalf("did not expect changes, got: %v", change) 110 | } 111 | select { 112 | case <-watch: 113 | t.Fatalf("Changes() watch channel unexpectedly closed") 114 | default: 115 | } 116 | 117 | wtxn = db.WriteTxn(table) 118 | table.Insert(wtxn, testObject{ID: 1}) 119 | table.Insert(wtxn, testObject{ID: 2}) 120 | table.Insert(wtxn, testObject{ID: 3}) 121 | wtxn.Commit() 122 | 123 | // Observe the objects. 124 | select { 125 | case <-watch: 126 | case <-time.After(time.Second): 127 | t.Fatalf("Changes() watch channel not closed after inserts") 128 | } 129 | 130 | changes, watch = changeIter.Next(db.ReadTxn()) 131 | n = 0 132 | for change := range changes { 133 | require.False(t, change.Deleted, "not deleted") 134 | n++ 135 | } 136 | require.Equal(t, 3, n, "expected 3 objects") 137 | 138 | // Delete the objects 139 | wtxn = db.WriteTxn(table) 140 | require.NoError(t, table.DeleteAll(wtxn), "DeleteAll") 141 | wtxn.Commit() 142 | 143 | // Partially observe the changes 144 | <-watch 145 | changes, watch = changeIter.Next(db.ReadTxn()) 146 | for change := range changes { 147 | require.True(t, change.Deleted, "expected Deleted") 148 | break 149 | } 150 | 151 | // Calling Next again after partially consuming the iterator 152 | // should return a closed watch channel. 153 | changes, watch = changeIter.Next(db.ReadTxn()) 154 | select { 155 | case <-watch: 156 | case <-time.After(time.Second): 157 | t.Fatalf("Changes() watch channel not closed!") 158 | } 159 | 160 | // Consume the rest of the deletions. 
161 | n = 1 162 | for change := range changes { 163 | require.True(t, change.Deleted, "expected Deleted") 164 | n++ 165 | } 166 | require.Equal(t, 3, n, "expected 3 deletions") 167 | } 168 | 169 | // Prefix and LowerBound searches on non-unique indexes did not properly check 170 | // whether the object was a false positive due to matching on the primary key part 171 | // of the composite key (). E.g. if the 172 | // composite keys were <1> and <2> then Prefix("aa") incorrectly 173 | // yielded the <1> as it matched partially the primary key . 174 | // 175 | // Also another issue existed with the ordering of the results due to there being 176 | // no separator between and parts of the composite key. 177 | // E.g. <1> and <2> were yielded in the incorrect order 178 | // <2> and <1>, which implied "aa" < "a"! 179 | func Test_Regression_Prefix_NonUnique(t *testing.T) { 180 | type object struct { 181 | ID string 182 | Tag string 183 | } 184 | idIndex := Index[object, string]{ 185 | Name: "id", 186 | FromObject: func(t object) index.KeySet { 187 | return index.NewKeySet(index.String(t.ID)) 188 | }, 189 | FromKey: index.String, 190 | Unique: true, 191 | } 192 | tagIndex := Index[object, string]{ 193 | Name: "tag", 194 | FromObject: func(t object) index.KeySet { 195 | return index.NewKeySet(index.String(t.Tag)) 196 | }, 197 | FromKey: index.String, 198 | Unique: false, 199 | } 200 | 201 | db, _, _ := newTestDB(t) 202 | table, err := NewTable("objects", idIndex, tagIndex) 203 | require.NoError(t, err) 204 | require.NoError(t, db.RegisterTable(table)) 205 | 206 | wtxn := db.WriteTxn(table) 207 | table.Insert(wtxn, object{"aa", "a"}) 208 | table.Insert(wtxn, object{"b", "bb"}) 209 | table.Insert(wtxn, object{"z", "b"}) 210 | wtxn.Commit() 211 | 212 | // The tag index has one object with tag "a", prefix searching 213 | // "aa" should return nothing. 
214 | txn := db.ReadTxn() 215 | iter := table.Prefix(txn, tagIndex.Query("aa")) 216 | items := Collect(iter) 217 | assert.Len(t, items, 0, "Prefix(\"aa\") should return nothing") 218 | 219 | iter = table.Prefix(txn, tagIndex.Query("a")) 220 | items = Collect(iter) 221 | if assert.Len(t, items, 1, "Prefix(\"a\") on tags should return one match") { 222 | assert.EqualValues(t, "aa", items[0].ID) 223 | } 224 | 225 | // Check prefix search ordering: should be fully defined by the secondary key. 226 | iter = table.Prefix(txn, tagIndex.Query("b")) 227 | items = Collect(iter) 228 | if assert.Len(t, items, 2, "Prefix(\"b\") on tags should return two matches") { 229 | assert.EqualValues(t, "z", items[0].ID) 230 | assert.EqualValues(t, "b", items[1].ID) 231 | } 232 | 233 | // With LowerBound search on "aa" we should see tags "b" and "bb" (in that order) 234 | iter = table.LowerBound(txn, tagIndex.Query("aa")) 235 | items = Collect(iter) 236 | if assert.Len(t, items, 2, "LowerBound(\"aa\") on tags should return two matches") { 237 | assert.EqualValues(t, "z", items[0].ID) 238 | assert.EqualValues(t, "b", items[1].ID) 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /script_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | // Copyright Authors of Cilium 3 | 4 | package statedb 5 | 6 | import ( 7 | "context" 8 | "maps" 9 | "slices" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/cilium/hive" 14 | "github.com/cilium/hive/cell" 15 | "github.com/cilium/hive/hivetest" 16 | "github.com/cilium/hive/script" 17 | "github.com/cilium/hive/script/scripttest" 18 | "github.com/stretchr/testify/assert" 19 | "github.com/stretchr/testify/require" 20 | ) 21 | 22 | func TestScript(t *testing.T) { 23 | log := hivetest.Logger(t) 24 | h := hive.New( 25 | Cell, // DB 26 | cell.Invoke(func(db *DB) { 27 | t1 := newTestObjectTable(t, "test1", tagsIndex) 28 | 
require.NoError(t, db.RegisterTable(t1), "RegisterTable") 29 | t2 := newTestObjectTable(t, "test2", tagsIndex) 30 | require.NoError(t, db.RegisterTable(t2), "RegisterTable") 31 | }), 32 | ) 33 | t.Cleanup(func() { 34 | assert.NoError(t, h.Stop(log, context.TODO())) 35 | }) 36 | cmds, err := h.ScriptCommands(log) 37 | require.NoError(t, err, "ScriptCommands") 38 | maps.Insert(cmds, maps.All(script.DefaultCmds())) 39 | engine := &script.Engine{ 40 | Cmds: cmds, 41 | } 42 | scripttest.Test(t, 43 | context.Background(), 44 | func(t testing.TB, args []string) *script.Engine { 45 | return engine 46 | }, 47 | []string{}, 48 | "testdata/*.txtar", 49 | ) 50 | } 51 | 52 | func TestHeaderLine(t *testing.T) { 53 | type retrieval struct { 54 | header string 55 | idxs []int 56 | } 57 | testCases := []struct { 58 | line string 59 | names []string 60 | pos []int 61 | get []retrieval 62 | }{ 63 | { 64 | "Foo Bar ", 65 | []string{"Foo", "Bar"}, 66 | []int{0, 6}, 67 | []retrieval{ 68 | {"Foo", []int{0}}, 69 | {"Bar", []int{1}}, 70 | {"Bar Foo Bar", []int{1, 0, 1}}, 71 | }, 72 | }, 73 | { 74 | "Foo Bar Quux", 75 | []string{"Foo", "Bar", "Quux"}, 76 | []int{0, 4, 10}, 77 | []retrieval{ 78 | {"Foo", []int{0}}, 79 | {"Bar", []int{1}}, 80 | {"Bar Foo", []int{1, 0}}, 81 | {"Quux", []int{2}}, 82 | {"Quux Foo", []int{2, 0}}, 83 | }, 84 | }, 85 | } 86 | 87 | for _, tc := range testCases { 88 | // Parse header line into names and positions 89 | names, pos := splitHeaderLine(tc.line) 90 | require.Equal(t, tc.names, names) 91 | require.Equal(t, tc.pos, pos) 92 | 93 | // Split the header line with the parsed positions. 94 | header := splitByPositions(tc.line, pos) 95 | require.Equal(t, tc.names, header) 96 | 97 | // Join the headers with the positions. 
98 | line := joinByPositions(header, pos) 99 | require.Equal(t, strings.TrimRight(tc.line, " \t"), line) 100 | 101 | // Test retrievals 102 | for _, r := range tc.get { 103 | names, pos = splitHeaderLine(r.header) 104 | idxs, err := getColumnIndexes(names, header) 105 | require.NoError(t, err) 106 | require.Equal(t, r.idxs, idxs) 107 | 108 | row := slices.Clone(header) 109 | cols := takeColumns(row, idxs) 110 | line := joinByPositions(cols, pos) 111 | require.Equal(t, line, r.header) 112 | } 113 | } 114 | } 115 | 116 | func TestSortedFlags(t *testing.T) { 117 | cases := []struct{ input, expected string }{ 118 | {"foo bar", "foo bar"}, 119 | {"bar foo", "bar foo"}, 120 | {"foo bar -baz=1", "-baz=1 foo bar"}, 121 | {"bar foo -baz=1", "-baz=1 bar foo"}, 122 | {"foo -baz=1 bar", "-baz=1 foo bar"}, 123 | {"-baz=1 foo bar", "-baz=1 foo bar"}, 124 | {"-baz=1 foo -quux=2 bar", "-baz=1 -quux=2 foo bar"}, 125 | {"-baz=1 bar foo -quux=2", "-baz=1 -quux=2 bar foo"}, 126 | } 127 | 128 | for _, tc := range cases { 129 | actual := strings.Join(sortedArgs(strings.Split(tc.input, " ")), " ") 130 | assert.Equal(t, tc.expected, actual) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /testdata/db.txtar: -------------------------------------------------------------------------------- 1 | # 2 | # This file is invoked by 'script_test.go' and tests the StateDB script commands 3 | # defined in 'script.go'. 4 | # 5 | 6 | hive start 7 | 8 | # Show the registered tables 9 | db 10 | 11 | # Initialized 12 | db/initialized 13 | db/initialized test1 14 | db/initialized test1 test2 15 | 16 | # Show (empty) 17 | db/show test1 18 | db/show test2 19 | 20 | # Assert empty 21 | db/empty test1 test2 22 | 23 | # Insert 24 | db/insert test1 obj1.yaml 25 | db/insert test1 obj2.yaml 26 | db/insert test2 obj2.yaml 27 | 28 | # Assert not empty 29 | ! 
db/empty test1 test2 30 | 31 | # Show (non-empty) 32 | db/show test1 33 | stdout ^ID.*Tags 34 | stdout 1.*bar 35 | stdout 2.*baz 36 | db/show test2 37 | 38 | db/show --format=table test1 39 | stdout ^ID.*Tags 40 | stdout 1.*bar 41 | stdout 2.*baz 42 | 43 | db/show --format=table --columns=Tags test1 44 | stdout ^Tags$ 45 | stdout '^bar, foo$' 46 | stdout '^baz, foo$' 47 | 48 | db/show -f table -o test1_show.table test1 49 | cmp test1.table test1_show.table 50 | 51 | db/show --format=yaml --out=test1_show.yaml test1 52 | cmp test1.yaml test1_show.yaml 53 | 54 | db/show --format=json -o=test1_show.json test1 55 | cmp test1.json test1_show.json 56 | 57 | # Get 58 | db/get test2 2 59 | db/get --format=table test2 2 60 | stdout '^ID.*Tags$' 61 | stdout ^2.*baz 62 | db/get --format=table --columns=Tags test2 2 63 | stdout ^Tags$ 64 | stdout '^baz, foo$' 65 | db/get --format=json test2 2 66 | db/get --format=yaml test2 2 67 | db/get --format=yaml -o=obj2_get.yaml test2 2 68 | cmp obj2.yaml obj2_get.yaml 69 | 70 | db/get -i tags -f yaml -o obj1_get.yaml test1 bar 71 | cmp obj1.yaml obj1_get.yaml 72 | 73 | # List 74 | db/list -o=list.table test1 1 75 | cmp obj1.table list.table 76 | db/list -o=list.table test1 2 77 | cmp obj2.table list.table 78 | 79 | db/list -o list.table -i tags test1 bar 80 | cmp obj1.table list.table 81 | db/list -o=list.table -i=tags test1 baz 82 | cmp obj2.table list.table 83 | db/list --out=list.table --index=tags test1 foo 84 | cmp objs.table list.table 85 | 86 | db/list --format=table --index=tags --columns=Tags test1 foo 87 | stdout ^Tags$ 88 | stdout '^bar, foo$' 89 | stdout '^baz, foo$' 90 | 91 | # Prefix 92 | # uint64 so can't really prefix search meaningfully, unless 93 | # FromString() accommodates partial keys.
94 | db/prefix test1 1 95 | 96 | db/prefix -o=prefix.table --index=tags test1 ba 97 | cmp objs.table prefix.table 98 | 99 | # LowerBound 100 | db/lowerbound -o=lb.table test1 0 101 | cmp objs.table lb.table 102 | db/lowerbound -o=lb.table test1 1 103 | cmp objs.table lb.table 104 | db/lowerbound -o=lb.table test1 2 105 | cmp obj2.table lb.table 106 | db/lowerbound -o=lb.table test1 3 107 | cmp empty.table lb.table 108 | 109 | # Compare 110 | db/cmp test1 objs.table 111 | db/cmp test1 objs_ids.table 112 | db/cmp --grep=bar test1 obj1.table 113 | db/cmp --grep=baz test1 obj2.table 114 | 115 | # Delete 116 | db/delete test1 obj1.yaml 117 | db/cmp test1 obj2.table 118 | 119 | db/delete test1 obj2.yaml 120 | db/cmp test1 empty.table 121 | 122 | # Delete with get 123 | db/insert test1 obj1.yaml 124 | db/cmp test1 obj1.table 125 | db/get --delete test1 1 126 | db/cmp test1 empty.table 127 | 128 | # Delete with prefix 129 | db/insert test1 obj1.yaml 130 | db/insert test1 obj2.yaml 131 | db/cmp test1 objs.table 132 | db/prefix --index=tags --delete test1 fo 133 | db/cmp test1 empty.table 134 | 135 | # Delete with lowerbound 136 | db/insert test1 obj1.yaml 137 | db/insert test1 obj2.yaml 138 | db/cmp test1 objs.table 139 | db/lowerbound --index=id --delete test1 2 140 | db/cmp test1 obj1.table 141 | 142 | # Tables 143 | db 144 | 145 | # --------------------- 146 | 147 | -- obj1.yaml -- 148 | id: 1 149 | tags: 150 | - bar 151 | - foo 152 | -- obj2.yaml -- 153 | id: 2 154 | tags: 155 | - baz 156 | - foo 157 | -- test1.yaml -- 158 | id: 1 159 | tags: 160 | - bar 161 | - foo 162 | --- 163 | id: 2 164 | tags: 165 | - baz 166 | - foo 167 | -- test1.json -- 168 | { 169 | "ID": 1, 170 | "Tags": [ 171 | "bar", 172 | "foo" 173 | ] 174 | } 175 | { 176 | "ID": 2, 177 | "Tags": [ 178 | "baz", 179 | "foo" 180 | ] 181 | } 182 | -- test1.table -- 183 | ID Tags 184 | 1 bar, foo 185 | 2 baz, foo 186 | -- objs.table -- 187 | ID Tags 188 | 1 bar, foo 189 | 2 baz, foo 190 | -- objs_ids.table -- 
// WatchSet is a set of watch channels that can be waited on.
//
// All methods are guarded by an internal mutex. Wait holds that mutex for
// its whole duration, so other methods called concurrently block until
// Wait returns.
type WatchSet struct {
	mu    sync.Mutex
	chans channelSet

	cases []reflect.SelectCase // for reuse in Wait()
}

// channelSet is the set of watched channels, keyed by the channel itself.
type channelSet = map[<-chan struct{}]struct{}

// NewWatchSet returns an empty WatchSet.
func NewWatchSet() *WatchSet {
	return &WatchSet{
		chans: channelSet{},
	}
}

// Add channel(s) to the watch set
func (ws *WatchSet) Add(chans ...<-chan struct{}) {
	ws.mu.Lock()
	for _, ch := range chans {
		ws.chans[ch] = struct{}{}
	}
	ws.mu.Unlock()
}

// Clear the channels from the WatchSet
func (ws *WatchSet) Clear() {
	ws.mu.Lock()
	clear(ws.chans)
	ws.mu.Unlock()
}

// Has returns true if the WatchSet has the channel
func (ws *WatchSet) Has(ch <-chan struct{}) bool {
	ws.mu.Lock()
	_, found := ws.chans[ch]
	ws.mu.Unlock()
	return found
}

// HasAny returns true if the WatchSet has any of the given channels
func (ws *WatchSet) HasAny(chans []<-chan struct{}) bool {
	ws.mu.Lock()
	defer ws.mu.Unlock()
	for _, ch := range chans {
		if _, found := ws.chans[ch]; found {
			return true
		}
	}
	return false
}

// Merge channels from another WatchSet. The other set is left unchanged.
// Merging a nil set or a set into itself is a no-op.
func (ws *WatchSet) Merge(other *WatchSet) {
	// The mutex is not reentrant: locking both sides of a self-merge would
	// deadlock, and there is nothing to copy anyway.
	if other == nil || other == ws {
		return
	}

	// Snapshot the other set under its own lock and release it before
	// taking ours. Never holding both locks at once rules out a lock-order
	// deadlock between concurrent a.Merge(b) and b.Merge(a).
	other.mu.Lock()
	snapshot := make([]<-chan struct{}, 0, len(other.chans))
	for ch := range other.chans {
		snapshot = append(snapshot, ch)
	}
	other.mu.Unlock()

	ws.Add(snapshot...)
}

// Wait for channels in the watch set to close or the context is cancelled.
// After the first closed channel is seen Wait will wait [settleTime] for
// more closed channels. With a zero or negative [settleTime] Wait returns
// as soon as it has seen at least one closed channel.
// Returns the closed channels and removes them from the set. The returned
// error is ctx.Err().
//
// If the set is empty Wait blocks until the context is done.
func (ws *WatchSet) Wait(ctx context.Context, settleTime time.Duration) ([]<-chan struct{}, error) {
	// innerCtx ends the collection loop below. It is done when the caller's
	// context is done, or when it is cancelled by the settle timer that is
	// started once the first closed channel has been seen.
	innerCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	ws.mu.Lock()
	defer ws.mu.Unlock()

	// No channels to watch? Just watch the context.
	if len(ws.chans) == 0 {
		<-ctx.Done()
		return nil, ctx.Err()
	}

	// Construct []SelectCase slice. Reuse the previous allocation.
	ws.cases = slices.Grow(ws.cases, 1+len(ws.chans))
	cases := ws.cases[:1+len(ws.chans)]
	cases[0] = reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(innerCtx.Done()),
	}
	casesIndex := 1
	for ch := range ws.chans {
		cases[casesIndex] = reflect.SelectCase{
			Dir:  reflect.SelectRecv,
			Chan: reflect.ValueOf(ch),
		}
		casesIndex++
	}

	var (
		closedChannels []<-chan struct{}
		settleTimer    *time.Timer
	)

	// Collect closed channels until [innerCtx] is cancelled.
	for {
		chosen, _, _ := reflect.Select(cases)
		if chosen == 0 /* == innerCtx.Done() */ {
			break
		}
		closedChannels = append(closedChannels, cases[chosen].Chan.Interface().(<-chan struct{}))
		// Drop the case by swapping in the last one; index 0 is never
		// touched since chosen > 0 here.
		cases[chosen] = cases[len(cases)-1]
		cases = cases[:len(cases)-1]

		// The settle period starts at the first closed channel, not when
		// Wait was called. Otherwise Wait would return empty-handed after
		// [settleTime] and could miss an already-closed channel entirely
		// when [settleTime] is zero.
		if len(closedChannels) == 1 {
			if settleTime <= 0 {
				cancel()
			} else {
				settleTimer = time.AfterFunc(settleTime, cancel)
			}
		}
	}
	if settleTimer != nil {
		settleTimer.Stop()
	}

	// Remove the closed channels from the watch set.
	for _, ch := range closedChannels {
		delete(ws.chans, ch)
	}

	return closedChannels, ctx.Err()
}
51 | 52 | i, j := rand.Intn(numChans), rand.Intn(numChans) 53 | for j == i { 54 | j = rand.Intn(numChans) 55 | } 56 | close(chans[i]) 57 | close(chans[j]) 58 | closed, err := ws.Wait(context.Background(), 50*time.Millisecond) 59 | require.NoError(t, err) 60 | require.ElementsMatch(t, closed, []<-chan struct{}{chans[i], chans[j]}, "i=%d, j=%d", i, j) 61 | cancel() 62 | } 63 | } 64 | 65 | func TestWatchSetInQueries(t *testing.T) { 66 | t.Parallel() 67 | db, table := newTestDBWithMetrics(t, &NopMetrics{}, tagsIndex) 68 | 69 | ws := NewWatchSet() 70 | txn := db.ReadTxn() 71 | _, watchAll := table.AllWatch(txn) 72 | 73 | // Should timeout as watches should not have closed yet. 74 | ws.Add(watchAll) 75 | ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) 76 | closed, err := ws.Wait(ctx, time.Second) 77 | require.ErrorIs(t, err, context.DeadlineExceeded) 78 | require.Empty(t, closed) 79 | cancel() 80 | 81 | // Insert some objects 82 | wtxn := db.WriteTxn(table) 83 | table.Insert(wtxn, testObject{ID: 1}) 84 | table.Insert(wtxn, testObject{ID: 2}) 85 | table.Insert(wtxn, testObject{ID: 3}) 86 | txn = wtxn.Commit() 87 | 88 | // The 'watchAll' channel should now have closed and Wait() returns. 89 | ws.Add(watchAll) 90 | closed, err = ws.Wait(context.Background(), 100*time.Millisecond) 91 | require.NoError(t, err) 92 | require.Len(t, closed, 1) 93 | require.True(t, closed[0] == watchAll) 94 | ws.Clear() 95 | 96 | // Try watching specific objects for changes. 97 | _, _, watch1, _ := table.GetWatch(txn, idIndex.Query(1)) 98 | _, _, watch2, _ := table.GetWatch(txn, idIndex.Query(2)) 99 | _, _, watch3, _ := table.GetWatch(txn, idIndex.Query(3)) 100 | 101 | wtxn = db.WriteTxn(table) 102 | table.Insert(wtxn, testObject{ID: 1, Tags: part.NewSet("foo")}) 103 | wtxn.Commit() 104 | 105 | // Use a new WatchSet and merge it. This allows having "subsets" that we 106 | // can then use to check whether the closed channel affected the subset. 
107 | ws2 := NewWatchSet() 108 | ws2.Add(watch3, watch2, watch1) 109 | 110 | // Merge into the larger WatchSet. This still leaves all the channels 111 | // in ws2. 112 | ws.Merge(ws2) 113 | 114 | closed, err = ws.Wait(context.Background(), 100*time.Millisecond) 115 | require.NoError(t, err) 116 | require.Len(t, closed, 1) 117 | require.True(t, closed[0] == watch1) 118 | require.True(t, ws2.Has(closed[0])) 119 | require.True(t, ws2.HasAny(closed)) 120 | 121 | ws2.Clear() 122 | require.False(t, ws2.Has(closed[0])) 123 | } 124 | 125 | func benchmarkWatchSet(b *testing.B, numChans int) { 126 | ws := NewWatchSet() 127 | for range numChans - 1 { 128 | ws.Add(make(chan struct{})) 129 | } 130 | 131 | b.ResetTimer() 132 | for range b.N { 133 | ws.Add(closedWatchChannel) 134 | ws.Wait(context.TODO(), 0) 135 | } 136 | } 137 | 138 | func BenchmarkWatchSet_4(b *testing.B) { 139 | benchmarkWatchSet(b, 4) 140 | } 141 | 142 | func BenchmarkWatchSet_16(b *testing.B) { 143 | benchmarkWatchSet(b, 16) 144 | } 145 | 146 | func BenchmarkWatchSet_128(b *testing.B) { 147 | benchmarkWatchSet(b, 128) 148 | } 149 | 150 | func BenchmarkWatchSet_1024(b *testing.B) { 151 | benchmarkWatchSet(b, 1024) 152 | } 153 | --------------------------------------------------------------------------------