├── .github └── workflows │ ├── build.yml │ └── golangci-lint.yml ├── .gitignore ├── .golangci.yaml ├── CONTRIBUTING ├── LICENSE ├── Makefile ├── NOTICE ├── README.md ├── ads ├── ads.go ├── ads_example_test.go ├── ads_test.go ├── glob_collection_url.go └── glob_collection_url_test.go ├── cache.go ├── cache_test.go ├── client.go ├── client_test.go ├── doc.go ├── examples └── quickstart │ └── main.go ├── go.mod ├── go.sum ├── internal ├── cache │ ├── glob_collection.go │ ├── glob_collections_map.go │ ├── resource_map.go │ ├── subscriber_set.go │ ├── subscriber_set_test.go │ ├── subscription_type.go │ ├── subscription_type_test.go │ ├── watchable_value.go │ └── watchable_value_test.go ├── client │ └── watchers.go ├── server │ ├── handlers.go │ ├── handlers_bench_test.go │ ├── handlers_delta.go │ ├── handlers_delta_test.go │ ├── handlers_test.go │ ├── limiter.go │ ├── limiter_test.go │ └── subscription_manager.go └── utils │ ├── set.go │ ├── utils.go │ └── utils_test.go ├── server.go ├── server_test.go ├── stats └── server │ └── server_stats.go ├── test_xds_config.json ├── testutils ├── testutils.go └── testutils_test.go ├── type.go └── type_test.go /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Setup Go 18 | uses: actions/setup-go@v5 19 | with: 20 | go-version: '1.23.x' 21 | - name: Install 22 | run: go get -v . 
23 | - name: Build 24 | run: make build 25 | - name: Test 26 | run: make test TESTVERBOSE=-v 27 | -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | permissions: 10 | contents: read 11 | # Optional: allow read access to pull request. Use with `only-new-issues` option. 12 | # pull-requests: read 13 | 14 | jobs: 15 | golangci: 16 | name: lint 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: actions/setup-go@v5 21 | with: 22 | go-version: "1.24.1" 23 | - name: golangci-lint 24 | uses: golangci/golangci-lint-action@v6 25 | with: 26 | version: "v1.64.8" 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_store 2 | .*.swp 3 | .*.swo 4 | *.iml 5 | *.ipr 6 | *.iws 7 | *.sublime-* 8 | .direnv/ 9 | .gradle/ 10 | .idea/ 11 | .vscode/ 12 | *.prof 13 | .cov 14 | out/ 15 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | linters: 2 | enable: 3 | - bodyclose 4 | - errname 5 | - errorlint 6 | - exhaustive 7 | - goconst 8 | - gofmt 9 | - goimports 10 | - gocritic 11 | - predeclared 12 | - usestdlibvars 13 | - unused 14 | -------------------------------------------------------------------------------- /CONTRIBUTING: -------------------------------------------------------------------------------- 1 | As a contributor, you represent that the code you submit is your original work or that of your employer (in which case you represent you have the right to bind your employer). 
By submitting code, you (and, if applicable, your employer) are licensing the submitted code to LinkedIn and the open source community subject to the BSD 2-Clause license. 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-CLAUSE LICENSE 2 | 3 | Copyright 2024 LinkedIn Corporation 4 | All Rights Reserved. 5 | 6 | Redistribution and use in source and binary forms, with or 7 | without modification, are permitted provided that the following 8 | conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright 11 | notice, this list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above 14 | copyright notice, this list of conditions and the following 15 | disclaimer in the documentation and/or other materials provided 16 | with the distribution. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PACKAGE = github.com/linkedin/diderot 2 | SOURCE_FILES = $(wildcard $(shell git ls-files)) 3 | PROFILES = $(PWD)/out 4 | COVERAGE = $(PROFILES)/diderot.cov 5 | GOBIN = $(shell go env GOPATH)/bin 6 | 7 | # "all" is invoked on a bare "make" call since it's the first recipe. It just formats the code and 8 | # checks that all packages can be compiled 9 | .PHONY: all 10 | all: fmt build 11 | 12 | build: 13 | go build -tags=examples -v ./... 14 | go test -v -c -o /dev/null $$(go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./...) 15 | 16 | tidy: 17 | go mod tidy 18 | 19 | vet: 20 | go vet ./... 21 | 22 | $(GOBIN)/goimports: 23 | go install golang.org/x/tools/cmd/goimports@latest 24 | 25 | .PHONY: fmt 26 | fmt: $(GOBIN)/goimports 27 | $(GOBIN)/goimports -w . 28 | 29 | # Can be used to change the number of tests run, defaults to 1 to prevent caching 30 | TESTCOUNT = 1 31 | # Can be used to add flags to the go test invocation: make test TESTFLAGS=-v 32 | TESTFLAGS = 33 | # Can be used to change which package gets tested, defaults to all packages. 34 | TESTPKG = ./... 35 | # Can be used to generate coverage reports for a specific package 36 | COVERPKG = . 37 | # The default coverage flags. Specifying COVERFLAGS= disables coverage, which makes the tests run 38 | faster. 39 | COVERFLAGS = -coverprofile=$(COVERAGE) -coverpkg=$(COVERPKG)/...
40 | 41 | test: 42 | @mkdir -p $(dir $(COVERAGE)) 43 | go test $(TESTVERBOSE) -race $(COVERFLAGS) -count=$(TESTCOUNT) $(TESTFLAGS) $(TESTPKG) 44 | go tool cover -func $(COVERAGE) | awk '/total:/{print "Coverage: "$$3}' 45 | 46 | .PHONY: $(COVERAGE) 47 | 48 | coverage: $(COVERAGE) 49 | $(COVERAGE): 50 | @mkdir -p $(@D) 51 | -$(MAKE) test 52 | go tool cover -html=$(COVERAGE) 53 | 54 | profile_cache: 55 | $(MAKE) -B $(PROFILES)/BenchmarkCacheThroughput.bench BENCH_PKG=. 56 | 57 | profile_handlers: 58 | $(MAKE) -B $(PROFILES)/BenchmarkHandlers.bench BENCH_PKG=./internal/server 59 | 60 | profile_typeurl: 61 | $(MAKE) -B $(PROFILES)/BenchmarkGetTrimmedTypeURL.bench BENCH_PKG=ads 62 | 63 | profile_parse_glob_urn: 64 | $(MAKE) -B $(PROFILES)/BenchmarkParseGlobCollectionURN.bench BENCH_PKG=ads 65 | 66 | profile_set_clear: 67 | $(MAKE) -B $(PROFILES)/BenchmarkValueSetClear.bench BENCH_PKG=./internal/cache 68 | 69 | profile_notification_loop: 70 | $(MAKE) -B $(PROFILES)/BenchmarkNotificationLoop.bench BENCH_PKG=./internal/cache 71 | 72 | BENCHCOUNT = 1 73 | BENCHTIME = 1s 74 | 75 | $(PROFILES)/%.bench: 76 | ifdef BENCH_PKG 77 | $(eval BENCHBIN=$(PROFILES)/$*) 78 | mkdir -p $(PROFILES) 79 | go test -c \ 80 | -o $(BENCHBIN) \ 81 | ./$(BENCH_PKG) 82 | cd $(BENCH_PKG) && $(BENCHBIN) \ 83 | -test.count $(BENCHCOUNT) \ 84 | -test.benchmem \ 85 | -test.bench="^$*$$" \ 86 | -test.cpuprofile $(PROFILES)/$*.cpu \ 87 | -test.memprofile $(PROFILES)/$*.mem \ 88 | -test.blockprofile $(PROFILES)/$*.block \ 89 | -test.benchtime $(BENCHTIME) \ 90 | -test.run "^$$" $(BENCHVERBOSE) \ 91 | .
| tee $(abspath $@) $(abspath $(BENCHOUT)) 92 | else 93 | $(error BENCH_PKG undefined) 94 | endif 95 | ifdef OPEN_PROFILES 96 | go tool pprof -http : $(BENCHBIN) $(PROFILES)/$*.cpu & \ 97 | go tool pprof -http : $(PROFILES)/$*.mem ; kill %1 98 | else 99 | $(info Not opening profiles since OPEN_PROFILES is not set) 100 | endif 101 | 102 | $(GOBIN)/pkgsite: 103 | go install golang.org/x/pkgsite/cmd/pkgsite@latest 104 | 105 | docs: $(GOBIN)/pkgsite 106 | $(GOBIN)/pkgsite -open . 107 | 108 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2024 LinkedIn Corporation 2 | All Rights Reserved. 3 | 4 | Licensed under the BSD 2-Clause License (the "License"). See License in the project root for license information. 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diderot 2 | (pronounced dee-duh-row) 3 | 4 | --- 5 | 6 | Diderot is a server implementation of 7 | the [xDS protocol](https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol) that makes it extremely easy and 8 | efficient to implement a control plane for your Envoy and gRPC services. For the most up-to-date information, please 9 | visit the [documentation](https://pkg.go.dev/github.com/linkedin/diderot). 10 | 11 | ## Quick Start Guide 12 | The only thing you need to implement to make your resources available via xDS is a 13 | `diderot.ResourceLocator`([link](https://pkg.go.dev/github.com/linkedin/diderot#ResourceLocator)). It is the interface 14 | exposed by the [ADS server implementation](https://pkg.go.dev/github.com/linkedin/diderot#ADSServer) which should 15 | contain the business logic of all your resource definitions and how to find them. 
To facilitate this implementation, 16 | Diderot provides an efficient, low-resource [cache](https://pkg.go.dev/github.com/linkedin/diderot#Cache) that supports 17 | highly concurrent updates. By leveraging the cache implementation for the heavy lifting, you will be able to focus on 18 | the meaningful part of operating your own xDS control plane: your resource definitions. 19 | 20 | Once you have implemented your `ResourceLocator`, you can simply drop in a `diderot.ADSServer` to your gRPC service, and 21 | you're ready to go! Please refer to the [examples/quickstart](examples/quickstart/main.go) package 22 | 23 | ## Features 24 | Diderot's ADS server implementation is a faithful implementation of the xDS protocol. This means it implements both the 25 | State-of-the-World and Delta/Incremental variants. It supports advanced features such as 26 | [glob collections](https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#glob), unlocking the more 27 | efficient alternative to the `EDS` stage: `LEDS` 28 | ([design doc](https://docs.google.com/document/d/1aZ9ddX99BOWxmfiWZevSB5kzLAfH2TS8qQDcCBHcfSE/edit#heading=h.mmb97owcrx3c)). 29 | 30 | -------------------------------------------------------------------------------- /ads/ads.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package ads provides a set of utilities and definitions around the Aggregated Discovery Service xDS 3 | protocol (ADS), such as convenient type aliases, constants and core definitions. 
4 | */ 5 | package ads 6 | 7 | import ( 8 | "encoding/binary" 9 | "encoding/hex" 10 | "errors" 11 | "log/slog" 12 | "sync" 13 | "time" 14 | 15 | cluster "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" 16 | core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" 17 | endpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" 18 | listener "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" 19 | route "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" 20 | tls "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" 21 | discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" 22 | runtime "github.com/envoyproxy/go-control-plane/envoy/service/runtime/v3" 23 | types "github.com/envoyproxy/go-control-plane/pkg/resource/v3" 24 | "google.golang.org/protobuf/proto" 25 | "google.golang.org/protobuf/types/known/anypb" 26 | "google.golang.org/protobuf/types/known/durationpb" 27 | ) 28 | 29 | // Alias to xDS types, for convenience and brevity. 30 | type ( 31 | // Server is the core interface that needs to be implemented by an xDS control plane. The 32 | // "Aggregated" service (i.e. ADS, the name of this package) service is type agnostic, the desired 33 | // type is specified in the request. This avoids the need for clients to open multiple streams when 34 | // requesting different types, along with not needing new service definitions such as 35 | // [github.com/envoyproxy/go-control-plane/envoy/service/endpoint/v3.EndpointDiscoveryServiceServer]. 36 | Server = discovery.AggregatedDiscoveryServiceServer 37 | // Node is an alias for the client information included in both Delta and SotW requests [core.Node]. 38 | Node = core.Node 39 | 40 | // SotWClient is an alias for the state-of-the-world client type 41 | // [discovery.AggregatedDiscoveryService_StreamAggregatedResourcesClient]. 
42 | SotWClient = discovery.AggregatedDiscoveryService_StreamAggregatedResourcesClient 43 | // SotWStream is an alias for the state-of-the-world stream type for the server 44 | // [discovery.AggregatedDiscoveryService_StreamAggregatedResourcesServer]. 45 | SotWStream = discovery.AggregatedDiscoveryService_StreamAggregatedResourcesServer 46 | // SotWDiscoveryRequest is an alias for the state-of-the-world request type 47 | // [discovery.DiscoveryRequest]. 48 | SotWDiscoveryRequest = discovery.DiscoveryRequest 49 | // SotWDiscoveryResponse is an alias for the state-of-the-world response type 50 | // [discovery.DiscoveryResponse]. 51 | SotWDiscoveryResponse = discovery.DiscoveryResponse 52 | 53 | // DeltaClient is an alias for the delta client type 54 | // [discovery.AggregatedDiscoveryService_DeltaAggregatedResourcesClient]. 55 | DeltaClient = discovery.AggregatedDiscoveryService_DeltaAggregatedResourcesClient 56 | // DeltaStream is an alias for the delta (also known as incremental) stream type for the server 57 | // [discovery.AggregatedDiscoveryService_DeltaAggregatedResourcesServer]. 58 | DeltaStream = discovery.AggregatedDiscoveryService_DeltaAggregatedResourcesServer 59 | // DeltaDiscoveryRequest is an alias for the delta request type [discovery.DeltaDiscoveryRequest]. 60 | DeltaDiscoveryRequest = discovery.DeltaDiscoveryRequest 61 | // DeltaDiscoveryResponse is an alias for the delta response type [discovery.DeltaDiscoveryResponse]. 62 | DeltaDiscoveryResponse = discovery.DeltaDiscoveryResponse 63 | 64 | // RawResource is a type alias for the core ADS type [discovery.Resource]. It is "raw" only in 65 | // contrast to the [*Resource] type defined in this package, which preserves the underlying 66 | // resource's type as a generic parameter. 67 | RawResource = discovery.Resource 68 | ) 69 | 70 | // NewResource is a convenience method for creating a new [*Resource]. 
71 | func NewResource[T proto.Message](name, version string, t T) *Resource[T] { 72 | return &Resource[T]{ 73 | Name: name, 74 | Version: version, 75 | Resource: t, 76 | } 77 | } 78 | 79 | // Resource is the typed equivalent of [RawResource] in that it preserves the underlying resource's 80 | // type at compile time. It defines the same fields as [RawResource] (except for unsupported fields 81 | // such as [ads.RawResource.Aliases]), and can be trivially serialized to a [RawResource] with 82 | // Marshal. It is undefined behavior to modify a [Resource] after creation. 83 | type Resource[T proto.Message] struct { 84 | Name string 85 | Version string 86 | Resource T 87 | Ttl *durationpb.Duration 88 | CacheControl *discovery.Resource_CacheControl 89 | Metadata *core.Metadata 90 | 91 | marshalOnce sync.Once 92 | marshaled *RawResource 93 | marshalErr error 94 | } 95 | 96 | // Marshal returns the serialized version of this Resource. Note that this result is cached, and can 97 | // be called repeatedly and from multiple goroutines. 98 | func (r *Resource[T]) Marshal() (*RawResource, error) { 99 | r.marshalOnce.Do(func() { 100 | var out *anypb.Any 101 | out, r.marshalErr = anypb.New(r.Resource) 102 | if r.marshalErr != nil { 103 | // This shouldn't really ever happen, especially when serializing to Any 104 | slog.Error( 105 | "Failed to serialize proto", 106 | "msg", r.Resource, 107 | "type", string(r.Resource.ProtoReflect().Descriptor().FullName()), 108 | "err", r.marshalErr, 109 | ) 110 | return 111 | } 112 | 113 | r.marshaled = &RawResource{ 114 | Name: r.Name, 115 | Version: r.Version, 116 | Resource: out, 117 | Ttl: r.Ttl, 118 | CacheControl: r.CacheControl, 119 | Metadata: r.Metadata, 120 | } 121 | }) 122 | return r.marshaled, r.marshalErr 123 | } 124 | 125 | // TypeURL returns the underlying resource's type URL. 
126 | func (r *Resource[T]) TypeURL() string { 127 | // A literal `Resource[proto.Message]` works well when the goal is to store the deserialized 128 | // [RawResource]. However, inferring the type URL cannot be done with [utils.GetTypeURL] since it 129 | // uses reflection on the generic type parameter, which in this case is just [proto.Message], which 130 | // causes a panic. Instead, infer the type directly from the deserialized message, avoiding the panic 131 | // if the type parameter is indeed simply proto.Message. 132 | return types.APITypePrefix + string(r.Resource.ProtoReflect().Descriptor().FullName()) 133 | } 134 | 135 | func (r *Resource[T]) Equals(other *Resource[T]) bool { 136 | if r == other { 137 | return true 138 | } 139 | if r == nil || other == nil { 140 | return false 141 | } 142 | return r.Name == other.Name && 143 | r.Version == other.Version && 144 | proto.Equal(r.Ttl, other.Ttl) && 145 | proto.Equal(r.CacheControl, other.CacheControl) && 146 | proto.Equal(r.Metadata, other.Metadata) && 147 | proto.Equal(r.Resource, other.Resource) 148 | } 149 | 150 | // UnmarshalRawResource unmarshals the given RawResource and returns a Resource of the corresponding 151 | // type. Resource.Marshal on the returned Resource will return the given RawResource instead of 152 | // re-serializing the resource. 153 | func UnmarshalRawResource[T proto.Message](raw *RawResource) (*Resource[T], error) { 154 | m, err := raw.Resource.UnmarshalNew() 155 | if err != nil { 156 | return nil, err 157 | } 158 | 159 | r := &Resource[T]{ 160 | Name: raw.Name, 161 | Version: raw.Version, 162 | Resource: m.(T), 163 | Ttl: raw.Ttl, 164 | CacheControl: raw.CacheControl, 165 | Metadata: raw.Metadata, 166 | } 167 | // Set marshaled using marshalOnce, otherwise the once will not be set and subsequent calls to 168 | // Marshal will serialize the resource, overwriting the field. 
169 | r.marshalOnce.Do(func() { 170 | r.marshaled = raw 171 | }) 172 | 173 | return r, nil 174 | } 175 | 176 | const ( 177 | // WildcardSubscription is a special resource name that triggers a subscription to all resources of a 178 | // given type. 179 | WildcardSubscription = "*" 180 | // XDSTPScheme is the prefix for which all resource URNs (as defined in the [TP1 proposal]) start. 181 | // 182 | // [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names 183 | XDSTPScheme = "xdstp://" 184 | ) 185 | 186 | // A SubscriptionHandler will receive notifications for the cache entries it has subscribed to using 187 | // RawCache.Subscribe. Note that it is imperative that implementations be hashable as it will be 188 | // stored as the key to a map (unhashable types include slices and functions). 189 | type SubscriptionHandler[T proto.Message] interface { 190 | // Notify is invoked when the given entry is modified. A deletion is denoted with a nil resource. The given time 191 | // parameters provides the time at which the client subscribed to the resource and the time at which the 192 | // modification happened respectively. Note that if an entry is modified repeatedly at a high rate, Notify will not 193 | // be invoked for all intermediate versions, though it will always *eventually* be invoked with the final version. 194 | Notify(name string, r *Resource[T], metadata SubscriptionMetadata) 195 | } 196 | 197 | // RawSubscriptionHandler is the untyped equivalent of SubscriptionHandler. 198 | type RawSubscriptionHandler interface { 199 | // Notify is the untyped equivalent of SubscriptionHandler.Notify. 200 | Notify(name string, r *RawResource, metadata SubscriptionMetadata) 201 | // ResourceMarshalError is invoked whenever a resource cannot be marshaled. This should be extremely 202 | // rare and requires immediate attention. 
When a resource cannot be marshaled, the notification will 203 | // be dropped and Notify will not be invoked. 204 | ResourceMarshalError(name string, resource proto.Message, err error) 205 | } 206 | 207 | // SubscriptionMetadata contains metadata about the subscription that triggered the Notify call on 208 | // the [RawSubscriptionHandler] or [SubscriptionHandler]. 209 | type SubscriptionMetadata struct { 210 | // The time at which the resource was subscribed to. 211 | SubscribedAt time.Time 212 | // The time at which the resource was modified (can be the 0-value if the modification time is unknown). 213 | ModifiedAt time.Time 214 | // The time at which the update to the resource was received by the cache (i.e. when [Cache.Set] was 215 | // called, not strictly when the server actually received the update). If this is metadata is for a 216 | // subscription to a resource that does not yet exist, will be the 0-value. 217 | CachedAt time.Time 218 | // The current priority index of the value. Will be 0 unless the backing cache was created with 219 | // [NewPrioritizedCache], [NewPrioritizedAggregateCache] or 220 | // [NewPrioritizedAggregateCachesByClientTypes]. If this metadata is for a subscription to a resource 221 | // that has been deleted (or does not yet exist), Priority will be the last valid index priority 222 | // index (because a resource is only considered deleted once it has been deleted from all cache 223 | // sources). For example, if the cache was created like this: 224 | // NewPrioritizedCache(10) 225 | // Then the last valid index is 9, since the slice of cache objects returned is of length 10. 226 | Priority int 227 | // The glob collection this resource belongs to, empty if it does not belong to any collections. 
228 | GlobCollectionURL string 229 | } 230 | 231 | // These aliases mirror the constants declared in [github.com/envoyproxy/go-control-plane/pkg/resource/v3] 232 | type ( 233 | Endpoint = endpoint.ClusterLoadAssignment 234 | LbEndpoint = endpoint.LbEndpoint 235 | Cluster = cluster.Cluster 236 | Route = route.RouteConfiguration 237 | ScopedRoute = route.ScopedRouteConfiguration 238 | VirtualHost = route.VirtualHost 239 | Listener = listener.Listener 240 | Secret = tls.Secret 241 | ExtensionConfig = core.TypedExtensionConfig 242 | Runtime = runtime.Runtime 243 | ) 244 | 245 | // StreamType is an enum representing the different possible ADS stream types, SotW and Delta. 246 | type StreamType int 247 | 248 | const ( 249 | // UnknownStreamType is the 0-value, unknown stream type. 250 | UnknownStreamType StreamType = iota 251 | // DeltaStreamType is the delta/incremental variant of the ADS protocol. 252 | DeltaStreamType 253 | // SotWStreamType is the state-of-the-world variant of the ADS protocol. 254 | SotWStreamType 255 | ) 256 | 257 | var streamTypeStrings = [...]string{"UNKNOWN", "Delta", "SotW"} 258 | 259 | func (t StreamType) String() string { 260 | return streamTypeStrings[t] 261 | } 262 | 263 | // StreamTypes is an array containing the valid [StreamType] values. 264 | var StreamTypes = [...]StreamType{UnknownStreamType, DeltaStreamType, SotWStreamType} 265 | 266 | // LookupStreamTypeByRPCMethod checks whether the given RPC method string (usually acquired from 267 | // [google.golang.org/grpc.StreamServerInfo.FullMethod] in the context of a server stream 268 | // interceptor) is either [SotWStreamType] or [DeltaStreamType]. Returns ([UnknownStreamType], false) 269 | // if it is neither. 
270 | func LookupStreamTypeByRPCMethod(rpcMethod string) (StreamType, bool) { 271 | switch rpcMethod { 272 | case discovery.AggregatedDiscoveryService_StreamAggregatedResources_FullMethodName: 273 | return SotWStreamType, true 274 | case discovery.AggregatedDiscoveryService_DeltaAggregatedResources_FullMethodName: 275 | return DeltaStreamType, true 276 | default: 277 | return UnknownStreamType, false 278 | } 279 | } 280 | 281 | var ( 282 | errInvalidNonceEncoding = errors.New("nonce isn't in hex encoding") 283 | errInvalidNonceLength = errors.New("decoded nonce did not have expected length") 284 | ) 285 | 286 | // ParseRemainingChunksFromNonce checks whether the Diderot server implementation chunked the delta 287 | // responses because not all resources could fit in the same response without going over the default 288 | // max gRPC message size of 4MB. A nonce from Diderot always starts with the 64-bit nanosecond 289 | // timestamp of when the response was generated on the server. Then the number of remaining chunks as 290 | // a 32-bit integer. The sequence of integers is binary encoded with [binary.BigEndian] then hex 291 | // encoded. If the given nonce does not match the expected format, this function simply returns 0 292 | // along with an error describing why it does not match. If the error isn't nil, it means the nonce 293 | // was not created by a Diderot server implementation, and therefore does not contain the expected 294 | // information. 
295 | func ParseRemainingChunksFromNonce(nonce string) (remainingChunks int, err error) { 296 | decoded, err := hex.DecodeString(nonce) 297 | if err != nil { 298 | return 0, errInvalidNonceEncoding 299 | } 300 | 301 | if len(decoded) != 12 { 302 | return 0, errInvalidNonceLength 303 | } 304 | 305 | return int(binary.BigEndian.Uint32(decoded[8:12])), nil 306 | } 307 | -------------------------------------------------------------------------------- /ads/ads_example_test.go: -------------------------------------------------------------------------------- 1 | package ads_test 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/linkedin/diderot/ads" 7 | ) 8 | 9 | func ExampleParseRemainingChunksFromNonce() { 10 | // Acquire a delta ADS client 11 | var client ads.DeltaClient 12 | 13 | var responses []*ads.DeltaDiscoveryResponse 14 | for { 15 | res, err := client.Recv() 16 | if err != nil { 17 | log.Panicf("Error receiving delta response: %v", err) 18 | } 19 | responses = append(responses, res) 20 | 21 | if remaining, _ := ads.ParseRemainingChunksFromNonce(res.Nonce); remaining == 0 { 22 | break 23 | } 24 | } 25 | 26 | log.Printf("All responses received: %+v", responses) 27 | } 28 | -------------------------------------------------------------------------------- /ads/ads_test.go: -------------------------------------------------------------------------------- 1 | package ads 2 | 3 | import ( 4 | "testing" 5 | 6 | types "github.com/envoyproxy/go-control-plane/pkg/resource/v3" 7 | "github.com/linkedin/diderot/internal/utils" 8 | "github.com/stretchr/testify/require" 9 | "google.golang.org/protobuf/proto" 10 | ) 11 | 12 | func TestTypeURL(t *testing.T) { 13 | require.Panics(t, func() { 14 | utils.GetTypeURL[proto.Message]() 15 | }) 16 | 17 | r := NewResource[proto.Message]("foo", "0", new(Endpoint)) 18 | require.Equal(t, types.EndpointType, r.TypeURL()) 19 | } 20 | -------------------------------------------------------------------------------- /ads/glob_collection_url.go: 
-------------------------------------------------------------------------------- 1 | package ads 2 | 3 | import ( 4 | "errors" 5 | "net/url" 6 | "strings" 7 | 8 | types "github.com/envoyproxy/go-control-plane/pkg/resource/v3" 9 | "google.golang.org/protobuf/proto" 10 | ) 11 | 12 | // GlobCollectionURL represents the individual elements of a glob collection URL. Please refer to the 13 | // [TP1 Proposal] for additional context on each field. In summary, a glob collection URL has the following format: 14 | // 15 | // xdstp://{Authority}/{ResourceType}/{Path}{?ContextParameters} 16 | // 17 | // [TP1 Proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names 18 | type GlobCollectionURL struct { 19 | // The URL's authority. Optional when URL of form "xdstp:///{ResourceType}/{Path}". 20 | Authority string 21 | // The type of the resources in the collection, without the "type.googleapis.com/" prefix. 22 | ResourceType string 23 | // The collection's path, without the trailing /* 24 | Path string 25 | // Optionally, the context parameters associated with the collection, always sorted by key name. If 26 | // present, starts with "?". 27 | ContextParameters string 28 | } 29 | 30 | func (u GlobCollectionURL) String() string { 31 | return u.uri(WildcardSubscription) 32 | } 33 | 34 | func (u GlobCollectionURL) MemberURN(name string) string { 35 | return u.uri(name) 36 | } 37 | 38 | func (u GlobCollectionURL) uri(name string) string { 39 | var path string 40 | switch u.Path { 41 | case "": 42 | path = name 43 | case "/": 44 | path = "/" + name 45 | default: 46 | path = u.Path + "/" + name 47 | } 48 | 49 | return XDSTPScheme + 50 | u.Authority + "/" + 51 | u.ResourceType + "/" + 52 | path + 53 | u.ContextParameters 54 | } 55 | 56 | // NewGlobCollectionURL creates a new [GlobCollectionURL] for the given type, authority, path and 57 | // context parameters. 
58 | func NewGlobCollectionURL[T proto.Message](authority, path string, contextParameters url.Values) GlobCollectionURL { 59 | return GlobCollectionURL{ 60 | Authority: authority, 61 | ResourceType: getTrimmedTypeURL[T](), 62 | Path: path, 63 | ContextParameters: contextParameters.Encode(), 64 | } 65 | } 66 | 67 | // RawNewGlobCollectionURL is the untyped equivalent of [NewGlobCollectionURL]. 68 | func RawNewGlobCollectionURL(authority, typeURL, path string, contextParameters url.Values) GlobCollectionURL { 69 | return GlobCollectionURL{ 70 | Authority: authority, 71 | ResourceType: strings.TrimPrefix(typeURL, types.APITypePrefix), 72 | Path: path, 73 | ContextParameters: contextParameters.Encode(), 74 | } 75 | } 76 | 77 | // ErrInvalidGlobCollectionURI is always returned by the various glob collection URL parsing 78 | // functions. 79 | var ErrInvalidGlobCollectionURI = errors.New("diderot: invalid glob collection URI") 80 | 81 | // TODO: the functions in this file return non-specific errors to avoid additional allocations during 82 | // cache updates, which can build up and get expensive. However this can be improved by having an 83 | // error for each of the various ways a string can be an invalid glob collection URL. 84 | 85 | // ParseGlobCollectionURL attempts to parse the given name as GlobCollectionURL, returning an error 86 | // if the given name does not represent one. See the [TP1 proposal] for additional context on the 87 | // exact definition of a glob collection. 
88 | // 89 | // [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names 90 | func ParseGlobCollectionURL[T proto.Message](name string) (GlobCollectionURL, error) { 91 | gcURL, resource, err := ParseGlobCollectionURN[T](name) 92 | if err != nil { 93 | return GlobCollectionURL{}, err 94 | } 95 | 96 | if resource != WildcardSubscription { 97 | // URLs must end with /* 98 | return GlobCollectionURL{}, ErrInvalidGlobCollectionURI 99 | } 100 | 101 | return gcURL, nil 102 | } 103 | 104 | // RawParseGlobCollectionURL is the untyped equivalent of [ParseGlobCollectionURL]. 105 | func RawParseGlobCollectionURL(typeURL, name string) (GlobCollectionURL, error) { 106 | gcURL, resource, err := RawParseGlobCollectionURN(typeURL, name) 107 | if err != nil { 108 | return GlobCollectionURL{}, err 109 | } 110 | 111 | if resource != WildcardSubscription { 112 | // URLs must end with /* 113 | return GlobCollectionURL{}, ErrInvalidGlobCollectionURI 114 | } 115 | 116 | return gcURL, nil 117 | } 118 | 119 | // ParseGlobCollectionURN checks if the given name is a resource URN, and returns the corresponding 120 | // GlobCollectionURL. 
The format of a resource URN is defined in the [TP1 proposal], and looks like 121 | // this: 122 | // 123 | // xdstp://[{authority}]/{resource type}/{id/*}?{context parameters} 124 | // 125 | // For example: 126 | // 127 | // xdstp://some-authority/envoy.config.listener.v3.Listener/foo/bar/baz 128 | // 129 | // In the above example, the URN belongs to this collection: 130 | // 131 | // xdstp://authority/envoy.config.listener.v3.Listener/foo/bar/* 132 | // 133 | // Note that in the above example, the URN does _not_ belong to the following collection: 134 | // 135 | // xdstp://authority/envoy.config.listener.v3.Listener/foo/* 136 | // 137 | // Glob collections are not recursive, and the {id/?} segment of the URN (after the type) should be 138 | // opaque, and not interpreted any further than the trailing /*. More details on this matter can be 139 | // found [here]. 140 | // 141 | // This function returns an error if the given name is not a resource URN. 142 | // 143 | // [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names 144 | // [here]: https://github.com/cncf/xds/issues/91 145 | func ParseGlobCollectionURN[T proto.Message](name string) (GlobCollectionURL, string, error) { 146 | return parseXDSTPURI(getTrimmedTypeURL[T](), name) 147 | } 148 | 149 | // RawParseGlobCollectionURN is the untyped equivalent of [ParseGlobCollectionURN]. 
150 | func RawParseGlobCollectionURN(typeURL, name string) (GlobCollectionURL, string, error) { 151 | return parseXDSTPURI(strings.TrimPrefix(typeURL, types.APITypePrefix), name) 152 | } 153 | 154 | func parseXDSTPURI(typeURL, resourceName string) (GlobCollectionURL, string, error) { 155 | // Skip deserializing the resource name if it doesn't start with the correct scheme 156 | if !strings.HasPrefix(resourceName, XDSTPScheme) { 157 | // doesn't start with xdstp:// 158 | return GlobCollectionURL{}, "", ErrInvalidGlobCollectionURI 159 | } 160 | 161 | parsedURL, err := url.Parse(resourceName) 162 | if err != nil { 163 | // invalid URL 164 | return GlobCollectionURL{}, "", ErrInvalidGlobCollectionURI 165 | } 166 | 167 | collectionPath, ok := strings.CutPrefix(parsedURL.EscapedPath(), "/"+typeURL+"/") 168 | if !ok { 169 | // should include expected type after authority 170 | return GlobCollectionURL{}, "", ErrInvalidGlobCollectionURI 171 | } 172 | 173 | u := GlobCollectionURL{ 174 | Authority: parsedURL.Host, 175 | ResourceType: typeURL, 176 | Path: collectionPath, 177 | } 178 | if len(parsedURL.RawQuery) > 0 { 179 | // Using .Query() to parse the query then .Encode() to re-serialize ensures the query parameters are 180 | // in the right sorted order. 181 | u.ContextParameters = "?" 
+ parsedURL.Query().Encode() 182 | } 183 | 184 | lastSlash := strings.LastIndex(u.Path, "/") 185 | if lastSlash == -1 { 186 | // Missing path in URL 187 | return GlobCollectionURL{}, "", ErrInvalidGlobCollectionURI 188 | } 189 | 190 | resource := u.Path[lastSlash+1:] 191 | 192 | if lastSlash == 0 { 193 | u.Path = "/" 194 | } else { 195 | u.Path = u.Path[:lastSlash] 196 | } 197 | 198 | return u, resource, nil 199 | } 200 | 201 | func getTrimmedTypeURL[T proto.Message]() string { 202 | var t T 203 | return string(t.ProtoReflect().Descriptor().FullName()) 204 | } 205 | -------------------------------------------------------------------------------- /ads/glob_collection_url_test.go: -------------------------------------------------------------------------------- 1 | package ads 2 | 3 | import ( 4 | "testing" 5 | 6 | cluster "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" 7 | endpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" 8 | listener "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" 9 | route "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" 10 | "github.com/envoyproxy/go-control-plane/pkg/resource/v3" 11 | "github.com/stretchr/testify/require" 12 | "google.golang.org/protobuf/proto" 13 | "google.golang.org/protobuf/types/known/wrapperspb" 14 | ) 15 | 16 | const ( 17 | resourceType = "google.protobuf.Int64Value" 18 | ) 19 | 20 | func testBadURIs(t *testing.T, parser func(string) (GlobCollectionURL, error)) { 21 | badURIs := []struct { 22 | name string 23 | resourceName string 24 | }{ 25 | { 26 | name: "empty name", 27 | resourceName: "", 28 | }, 29 | { 30 | name: "invalid prefix", 31 | resourceName: "https://foo/bar", 32 | }, 33 | { 34 | name: "wrong type", 35 | resourceName: "xdstp://auth/some.other.type/foo", 36 | }, 37 | { 38 | name: "empty id", 39 | resourceName: "xdstp://auth/google.protobuf.Int64Value", 40 | }, 41 | { 42 | name: "empty id trailing slash", 43 | resourceName: 
"xdstp://auth/google.protobuf.Int64Value/", 44 | }, 45 | { 46 | name: "invalid query", 47 | resourceName: "xdstp://auth/google.protobuf.Int64Value/foo?asd", 48 | }, 49 | } 50 | 51 | for _, test := range badURIs { 52 | t.Run(test.name, func(t *testing.T) { 53 | _, err := parser(test.resourceName) 54 | require.Error(t, err) 55 | }) 56 | } 57 | } 58 | 59 | func testGoodURIs(t *testing.T, id string, parser func(string) (GlobCollectionURL, error)) { 60 | tests := []struct { 61 | name string 62 | resourceName string 63 | expected GlobCollectionURL 64 | expectErr bool 65 | }{ 66 | { 67 | name: "standard", 68 | resourceName: "xdstp://auth/google.protobuf.Int64Value/foo/" + id, 69 | expected: GlobCollectionURL{ 70 | Authority: "auth", 71 | ResourceType: resourceType, 72 | Path: "foo", 73 | ContextParameters: "", 74 | }, 75 | }, 76 | { 77 | name: "empty authority", 78 | resourceName: "xdstp:///google.protobuf.Int64Value/foo/" + id, 79 | expected: GlobCollectionURL{ 80 | Authority: "", 81 | ResourceType: resourceType, 82 | Path: "foo", 83 | ContextParameters: "", 84 | }, 85 | }, 86 | { 87 | name: "nested", 88 | resourceName: "xdstp://auth/google.protobuf.Int64Value/foo/bar/baz/" + id, 89 | expected: GlobCollectionURL{ 90 | Authority: "auth", 91 | ResourceType: resourceType, 92 | Path: "foo/bar/baz", 93 | ContextParameters: "", 94 | }, 95 | }, 96 | { 97 | name: "with query", 98 | resourceName: "xdstp://auth/google.protobuf.Int64Value/foo/" + id + "?asd=123", 99 | expected: GlobCollectionURL{ 100 | Authority: "auth", 101 | ResourceType: resourceType, 102 | Path: "foo", 103 | ContextParameters: "?asd=123", 104 | }, 105 | }, 106 | { 107 | name: "with unsorted query", 108 | resourceName: "xdstp://auth/google.protobuf.Int64Value/foo/" + id + "?b=2&a=1", 109 | expected: GlobCollectionURL{ 110 | Authority: "auth", 111 | ResourceType: resourceType, 112 | Path: "foo", 113 | ContextParameters: "?a=1&b=2", 114 | }, 115 | }, 116 | { 117 | name: "empty query", 118 | resourceName: 
"xdstp://auth/google.protobuf.Int64Value/foo/" + id + "?", 119 | expected: GlobCollectionURL{ 120 | Authority: "auth", 121 | ResourceType: resourceType, 122 | Path: "foo", 123 | ContextParameters: "", 124 | }, 125 | }, 126 | } 127 | 128 | for _, test := range tests { 129 | t.Run(test.name, func(t *testing.T) { 130 | actual, err := parser(test.resourceName) 131 | if test.expectErr { 132 | require.Error(t, err) 133 | } else { 134 | require.NoError(t, err) 135 | require.Equal(t, test.expected, actual) 136 | } 137 | }) 138 | } 139 | } 140 | 141 | func TestParseGlobCollectionURL(t *testing.T) { 142 | genericParser := ParseGlobCollectionURL[*wrapperspb.Int64Value] 143 | typeURL := getTrimmedTypeURL[*wrapperspb.Int64Value]() 144 | rawParser := func(s string) (GlobCollectionURL, error) { 145 | return RawParseGlobCollectionURL(typeURL, s) 146 | } 147 | 148 | t.Run("bad URIs", func(t *testing.T) { 149 | testBadURIs(t, genericParser) 150 | testBadURIs(t, rawParser) 151 | }) 152 | t.Run("good URIs", func(t *testing.T) { 153 | testGoodURIs(t, WildcardSubscription, genericParser) 154 | testGoodURIs(t, WildcardSubscription, rawParser) 155 | }) 156 | t.Run("rejects URNs", func(t *testing.T) { 157 | _, err := ParseGlobCollectionURL[*wrapperspb.Int64Value]("xdstp:///" + resourceType + "/foo/bar") 158 | require.Error(t, err) 159 | _, err = RawParseGlobCollectionURL(typeURL, "xdstp:///"+resourceType+"/foo/bar") 160 | require.Error(t, err) 161 | }) 162 | } 163 | 164 | func TestParseGlobCollectionURN(t *testing.T) { 165 | genericParser := func(s string) (GlobCollectionURL, error) { 166 | gcURL, _, err := ParseGlobCollectionURN[*wrapperspb.Int64Value](s) 167 | return gcURL, err 168 | } 169 | typeURL := getTrimmedTypeURL[*wrapperspb.Int64Value]() 170 | rawParser := func(s string) (GlobCollectionURL, error) { 171 | gcURL, _, err := RawParseGlobCollectionURN(typeURL, s) 172 | return gcURL, err 173 | } 174 | 175 | t.Run("bad URIs", func(t *testing.T) { 176 | testBadURIs(t, genericParser) 177 
| testBadURIs(t, rawParser) 178 | }) 179 | t.Run("good URIs", func(t *testing.T) { 180 | testGoodURIs(t, "foo", genericParser) 181 | testGoodURIs(t, "foo", rawParser) 182 | }) 183 | t.Run("handles glob collection URLs", func(t *testing.T) { 184 | gcURL, r, err := ParseGlobCollectionURN[*wrapperspb.Int64Value]("xdstp:///" + resourceType + "/foo/*") 185 | require.NoError(t, err) 186 | require.Equal(t, NewGlobCollectionURL[*wrapperspb.Int64Value]("", "foo", nil), gcURL) 187 | require.Equal(t, WildcardSubscription, r) 188 | 189 | gcURL, r, err = RawParseGlobCollectionURN(typeURL, "xdstp:///"+resourceType+"/foo/*") 190 | require.NoError(t, err) 191 | require.Equal(t, RawNewGlobCollectionURL("", typeURL, "foo", nil), gcURL) 192 | require.Equal(t, WildcardSubscription, r) 193 | }) 194 | } 195 | 196 | func TestGetTrimmedTypeURL(t *testing.T) { 197 | check := func(expected, actualTrimmed string) { 198 | require.Equal(t, expected, resource.APITypePrefix+actualTrimmed) 199 | } 200 | check(resource.ListenerType, getTrimmedTypeURL[*listener.Listener]()) 201 | check(resource.EndpointType, getTrimmedTypeURL[*endpoint.ClusterLoadAssignment]()) 202 | check(resource.ClusterType, getTrimmedTypeURL[*cluster.Cluster]()) 203 | check(resource.RouteType, getTrimmedTypeURL[*route.RouteConfiguration]()) 204 | } 205 | 206 | func BenchmarkGetTrimmedTypeURL(b *testing.B) { 207 | benchmarkGetTrimmedTypeURL[*wrapperspb.Int64Value](b) 208 | benchmarkGetTrimmedTypeURL[*cluster.Cluster](b) 209 | } 210 | 211 | func benchmarkGetTrimmedTypeURL[T proto.Message](b *testing.B) { 212 | b.Run(getTrimmedTypeURL[T](), func(b *testing.B) { 213 | var url string 214 | for range b.N { 215 | url = getTrimmedTypeURL[T]() 216 | } 217 | require.NotEmpty(b, url) 218 | }) 219 | } 220 | 221 | // This benchmark validates that using reflection on the protobuf type does not incur any unexpected 222 | // costs. The most expensive operation is currently [url.Parse], which can't easily be removed. 
It is 223 | // critical that this operation remain as inexpensive as possible because it is called repeatedly in 224 | // the cache (for each insertion or deletion). Current results: 225 | // 226 | // goos: darwin 227 | // goarch: amd64 228 | // pkg: github.com/linkedin/diderot/ads 229 | // cpu: VirtualApple @ 2.50GHz 230 | // │ results │ 231 | // │ sec/op │ 232 | // ParseGlobCollectionURN/xdstp://foo/google.protobuf.Int64Value/bar/*/generic-8 624.2n ± 3% 233 | // ParseGlobCollectionURN/xdstp://foo/google.protobuf.Int64Value/bar/*/raw-8 610.4n ± 0% 234 | // ParseGlobCollectionURN/xdstp://foo/google.protobuf.Int64Value/bar/*/raw_with_prefix-8 609.6n ± 0% 235 | // ParseGlobCollectionURN/xdstp://foo/envoy.config.cluster.v3.Cluster/bar/*/generic-8 697.1n ± 5% 236 | // ParseGlobCollectionURN/xdstp://foo/envoy.config.cluster.v3.Cluster/bar/*/raw-8 692.4n ± 1% 237 | // ParseGlobCollectionURN/xdstp://foo/envoy.config.cluster.v3.Cluster/bar/*/raw_with_prefix-8 691.5n ± 0% 238 | // geomean 653.0n 239 | // 240 | // │ results │ 241 | // │ B/op │ 242 | // ParseGlobCollectionURN/xdstp://foo/google.protobuf.Int64Value/bar/*/generic-8 192.0 ± 0% 243 | // ParseGlobCollectionURN/xdstp://foo/google.protobuf.Int64Value/bar/*/raw-8 192.0 ± 0% 244 | // ParseGlobCollectionURN/xdstp://foo/google.protobuf.Int64Value/bar/*/raw_with_prefix-8 192.0 ± 0% 245 | // ParseGlobCollectionURN/xdstp://foo/envoy.config.cluster.v3.Cluster/bar/*/generic-8 240.0 ± 0% 246 | // ParseGlobCollectionURN/xdstp://foo/envoy.config.cluster.v3.Cluster/bar/*/raw-8 240.0 ± 0% 247 | // ParseGlobCollectionURN/xdstp://foo/envoy.config.cluster.v3.Cluster/bar/*/raw_with_prefix-8 240.0 ± 0% 248 | // geomean 214.7 249 | // 250 | // │ results │ 251 | // │ allocs/op │ 252 | // ParseGlobCollectionURN/xdstp://foo/google.protobuf.Int64Value/bar/*/generic-8 2.000 ± 0% 253 | // ParseGlobCollectionURN/xdstp://foo/google.protobuf.Int64Value/bar/*/raw-8 2.000 ± 0% 254 | // 
ParseGlobCollectionURN/xdstp://foo/google.protobuf.Int64Value/bar/*/raw_with_prefix-8 2.000 ± 0% 255 | // ParseGlobCollectionURN/xdstp://foo/envoy.config.cluster.v3.Cluster/bar/*/generic-8 3.000 ± 0% 256 | // ParseGlobCollectionURN/xdstp://foo/envoy.config.cluster.v3.Cluster/bar/*/raw-8 3.000 ± 0% 257 | // ParseGlobCollectionURN/xdstp://foo/envoy.config.cluster.v3.Cluster/bar/*/raw_with_prefix-8 3.000 ± 0% 258 | // geomean 2.449 259 | func BenchmarkParseGlobCollectionURN(b *testing.B) { 260 | benchmarkParseGlobCollectionURN[*wrapperspb.Int64Value](b) 261 | benchmarkParseGlobCollectionURN[*cluster.Cluster](b) 262 | } 263 | 264 | func benchmarkParseGlobCollectionURN[T proto.Message](b *testing.B) { 265 | expectedURL := NewGlobCollectionURL[T]("foo", "bar", nil) 266 | url := expectedURL.String() 267 | 268 | run := func(b *testing.B, f func() (GlobCollectionURL, string, error)) { 269 | var actualURL GlobCollectionURL 270 | var err error 271 | for range b.N { 272 | actualURL, _, err = f() 273 | if err != nil { 274 | b.Fatal(err) 275 | } 276 | } 277 | require.Equal(b, expectedURL, actualURL) 278 | } 279 | 280 | b.Run(url, func(b *testing.B) { 281 | b.Run("generic", func(b *testing.B) { 282 | run(b, func() (GlobCollectionURL, string, error) { 283 | return ParseGlobCollectionURN[T](url) 284 | }) 285 | }) 286 | 287 | typeURL := getTrimmedTypeURL[T]() 288 | b.Run("raw", func(b *testing.B) { 289 | run(b, func() (GlobCollectionURL, string, error) { 290 | return RawParseGlobCollectionURN(typeURL, url) 291 | }) 292 | }) 293 | 294 | prefixedTypeURL := getTrimmedTypeURL[T]() 295 | b.Run("raw with prefix", func(b *testing.B) { 296 | run(b, func() (GlobCollectionURL, string, error) { 297 | return RawParseGlobCollectionURN(prefixedTypeURL, url) 298 | }) 299 | }) 300 | }) 301 | } 302 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package diderot 2 | 3 | 
import ( 4 | "context" 5 | "fmt" 6 | "iter" 7 | "log/slog" 8 | "slices" 9 | "sync" 10 | "time" 11 | 12 | discoveryv3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" 13 | "github.com/linkedin/diderot/ads" 14 | internal "github.com/linkedin/diderot/internal/client" 15 | "github.com/linkedin/diderot/internal/utils" 16 | "google.golang.org/grpc" 17 | "google.golang.org/grpc/codes" 18 | "google.golang.org/grpc/status" 19 | "google.golang.org/protobuf/proto" 20 | ) 21 | 22 | type ADSClientOption func(*options) 23 | 24 | const ( 25 | defaultInitialReconnectBackoff = 100 * time.Millisecond 26 | defaultMaxReconnectBackoff = 2 * time.Minute 27 | defaultResponseChunkingSupported = true 28 | ) 29 | 30 | // NewADSClient creates a new [*ADSClient] with the given options. To stop the client, close the 31 | // backing [grpc.ClientConn]. 32 | func NewADSClient(conn grpc.ClientConnInterface, node *ads.Node, opts ...ADSClientOption) *ADSClient { 33 | c := &ADSClient{ 34 | conn: conn, 35 | node: node, 36 | newSubscription: make(chan struct{}, 1), 37 | handlers: make(map[string]internal.RawResourceHandler), 38 | options: options{ 39 | initialReconnectBackoff: defaultInitialReconnectBackoff, 40 | maxReconnectBackoff: defaultMaxReconnectBackoff, 41 | responseChunkingSupported: defaultResponseChunkingSupported, 42 | }, 43 | } 44 | 45 | for _, opt := range opts { 46 | opt(&c.options) 47 | } 48 | 49 | go c.loop() 50 | 51 | return c 52 | } 53 | 54 | type options struct { 55 | initialReconnectBackoff time.Duration 56 | maxReconnectBackoff time.Duration 57 | responseChunkingSupported bool 58 | } 59 | 60 | // WithReconnectBackoff provides backoff configuration when reconnecting to the xDS backend after a 61 | // connection failure. The default settings are 100ms and 2m for the initial and max backoff 62 | // respectively. 
63 | func WithReconnectBackoff(initialBackoff, maxBackoff time.Duration) ADSClientOption { 64 | return func(o *options) { 65 | o.initialReconnectBackoff = initialBackoff 66 | o.maxReconnectBackoff = maxBackoff 67 | } 68 | } 69 | 70 | // WithResponseChunkingSupported changes whether response chunking should be supported (see 71 | // [ads.ParseRemainingChunksFromNonce] for additional details). This feature is only provided by the 72 | // [ADSServer] implemented in this package. This enabled by default. 73 | func WithResponseChunkingSupported(supported bool) ADSClientOption { 74 | return func(o *options) { 75 | o.responseChunkingSupported = supported 76 | } 77 | } 78 | 79 | // An ADSClient is a client that implements the xDS protocol, and can therefore be used to talk to 80 | // any xDS backend. Use the [Watch], [WatchGlob] and [WatchWildcard] to subscribe to resources. 81 | type ADSClient struct { 82 | options 83 | node *ads.Node 84 | conn grpc.ClientConnInterface 85 | 86 | newSubscription chan struct{} 87 | 88 | lock sync.Mutex 89 | handlers map[string]internal.RawResourceHandler 90 | } 91 | 92 | // A Watcher is used to receive updates from the xDS backend using an [ADSClient]. It is passed into 93 | // the various [Watch] methods in this package. Note that it is imperative that implementations be 94 | // hashable as it will be stored as the key to a map (unhashable types include slices and functions). 95 | type Watcher[T proto.Message] interface { 96 | // Notify is invoked whenever a response is processed. The given sequence will iterate over all the 97 | // resources in the response, with a nil resource indicating a deletion. Implementations should 98 | // return an error if any resource is invalid, and this error will be propagated as a NACK to the xDS 99 | // backend. 
100 | Notify(resources iter.Seq2[string, *ads.Resource[T]]) error 101 | } 102 | 103 | // Watch registers the given watcher in the given client, triggering a subscription (if necessary) 104 | // for the given resource name such that the [Watcher] will be notified whenever the resource is 105 | // updated. If a resource is already known (for example from a previous existing subscription), the 106 | // watcher will be immediately notified. Glob or wildcard subscriptions are supported, and 107 | // [Watcher.Notify] will be invoked with a sequence that iterates over all the updated resources. 108 | func Watch[T proto.Message](c *ADSClient, name string, watcher Watcher[T]) { 109 | if getResourceHandler[T](c).AddWatcher(name, watcher) { 110 | c.notifyNewSubscription() 111 | } 112 | } 113 | 114 | // getResourceHandler gets or initializes the [internal.ResourceHandler] for the specified type in 115 | // the given client. 116 | func getResourceHandler[T proto.Message](c *ADSClient) *internal.ResourceHandler[T] { 117 | c.lock.Lock() 118 | defer c.lock.Unlock() 119 | 120 | typeURL := utils.GetTypeURL[T]() 121 | if hAny, ok := c.handlers[typeURL]; !ok { 122 | h := internal.NewResourceHandler[T]() 123 | c.handlers[typeURL] = h 124 | return h 125 | } else { 126 | return hAny.(*internal.ResourceHandler[T]) 127 | } 128 | } 129 | 130 | func (c *ADSClient) getResourceHandler(typeURL string) (internal.RawResourceHandler, bool) { 131 | c.lock.Lock() 132 | defer c.lock.Unlock() 133 | h, ok := c.handlers[typeURL] 134 | return h, ok 135 | } 136 | 137 | // notifyNewSubscription signals to the subscription loop that a new subscription was added. 138 | func (c *ADSClient) notifyNewSubscription() { 139 | select { 140 | case c.newSubscription <- struct{}{}: 141 | default: 142 | } 143 | } 144 | 145 | // This is a type alias for the set of resources the client is subscribed to. The key is the typeURL 146 | // and the value is the set of resource names subscribed to within that type. 
147 | type subscriptionSet map[string]utils.Set[string] 148 | 149 | // getPendingSubscriptions iterates over all the subscriptions returned by invoking 150 | // [internal.ResourceHandler.AllSubscriptions] on all registered resource handlers, and compares it 151 | // against the given set of already registered subscriptions. If any are missing, they are added to 152 | // the returned subscription set after being added to the given set. This means that repeated 153 | // invocations of this method will return an empty set if no new subscriptions are added in between. 154 | func (c *ADSClient) getPendingSubscriptions(registeredSubscriptions subscriptionSet) subscriptionSet { 155 | c.lock.Lock() 156 | defer c.lock.Unlock() 157 | 158 | pendingSubscriptions := make(subscriptionSet) 159 | add := func(typeURL string, name string) { 160 | registered := internal.GetNestedMap(registeredSubscriptions, typeURL) 161 | if !registered.Contains(name) { 162 | registered.Add(name) 163 | internal.GetNestedMap(pendingSubscriptions, typeURL).Add(name) 164 | } 165 | } 166 | 167 | for t, handler := range c.handlers { 168 | for k := range handler.AllSubscriptions() { 169 | add(t, k) 170 | } 171 | } 172 | 173 | return pendingSubscriptions 174 | } 175 | 176 | // loop simply calls newStream and subscriptionLoop forever, until the underlying gRPC connection is 177 | // closed. 178 | func (c *ADSClient) loop() { 179 | for { 180 | // See documentation on subscriptionLoop. It returns when the stream ends, so a fresh stream needs to 181 | // be created every time. 182 | stream, responses, err := c.newStream() 183 | if err != nil { 184 | return 185 | } 186 | 187 | err = c.subscriptionLoop(stream, responses) 188 | slog.WarnContext(stream.Context(), "Restarting ADS stream", "err", err) 189 | } 190 | } 191 | 192 | // subscriptionLoop is the critical logic loop for the client. It polls the given responses channel, 193 | // notifying watchers when new responses come in. 
Each slice returned by the responses channel is 194 | // expected to contain responses that are all for the same typeURL. In most cases, the slice will 195 | // only have one response in it, but if response chunking is supported, the slice will have all the 196 | // response chunks in it. It also waits for any new subscriptions to be registered, and sends them to 197 | // the server. This returns whenever the stream ends. 198 | func (c *ADSClient) subscriptionLoop(stream deltaClient, responsesCh <-chan []*ads.DeltaDiscoveryResponse) error { 199 | registeredSubscriptions := make(subscriptionSet) 200 | 201 | sendPendingSubscriptions := func() error { 202 | pending := c.getPendingSubscriptions(registeredSubscriptions) 203 | if len(pending) == 0 { 204 | return nil 205 | } 206 | 207 | slog.InfoContext(stream.Context(), "Subscribing to resources", "subscriptions", pending) 208 | for t, subs := range pending { 209 | err := stream.Send(&ads.DeltaDiscoveryRequest{ 210 | Node: c.node, 211 | TypeUrl: t, 212 | ResourceNamesSubscribe: slices.Collect(subs.Values()), 213 | }) 214 | if err != nil { 215 | return err 216 | } 217 | } 218 | return nil 219 | } 220 | 221 | isFirst := true 222 | for { 223 | err := sendPendingSubscriptions() 224 | if err != nil { 225 | return err 226 | } 227 | 228 | select { 229 | case <-c.newSubscription: 230 | err := sendPendingSubscriptions() 231 | if err != nil { 232 | return err 233 | } 234 | case responses := <-responsesCh: 235 | h, ok := c.getResourceHandler(responses[0].TypeUrl) 236 | if !ok { 237 | for _, res := range responses { 238 | err := c.sendACKOrNACK( 239 | stream, 240 | res, 241 | fmt.Errorf("received response with unknown type: %q", res.TypeUrl), 242 | ) 243 | if err != nil { 244 | slog.WarnContext(stream.Context(), "ADS stream closed", "err", err) 245 | return err 246 | } 247 | } 248 | continue 249 | } 250 | 251 | // Always ACK all but the last response. 
Errors will only be reported back to the server once all 252 | // chunks are processed. 253 | for _, res := range responses[:len(responses)-1] { 254 | err := c.sendACKOrNACK(stream, res, nil) 255 | if err != nil { 256 | return err 257 | } 258 | } 259 | 260 | handlerErr := h.HandleResponses(isFirst, responses) 261 | isFirst = false 262 | if err = c.sendACKOrNACK(stream, responses[len(responses)-1], handlerErr); err != nil { 263 | return err 264 | } 265 | case <-stream.Context().Done(): 266 | return stream.Context().Err() 267 | } 268 | } 269 | } 270 | 271 | // sendACKOrNACK will send an ACK or NACK (depending on the given error) for the given response. 272 | func (c *ADSClient) sendACKOrNACK(stream deltaClient, res *ads.DeltaDiscoveryResponse, err error) error { 273 | req := &ads.DeltaDiscoveryRequest{ 274 | Node: c.node, 275 | TypeUrl: res.TypeUrl, 276 | ResponseNonce: res.Nonce, 277 | } 278 | if err != nil { 279 | req.ErrorDetail = status.New(codes.InvalidArgument, err.Error()).Proto() 280 | slog.WarnContext(stream.Context(), "NACKing response", "res", res, "err", err) 281 | } else { 282 | slog.DebugContext(stream.Context(), "ACKing response", "res", res) 283 | } 284 | return stream.Send(req) 285 | } 286 | 287 | // newStream acquires a fresh stream from getDeltaClient and kicks off a goroutine that will read all 288 | // responses from the stream, writing them to the returned channel. The goroutine will exit when the 289 | // stream ends. 
290 | func (c *ADSClient) newStream() (deltaClient, <-chan []*ads.DeltaDiscoveryResponse, error) { 291 | stream, err := c.getDeltaClient() 292 | if err != nil { 293 | return nil, nil, err 294 | } 295 | 296 | responses := make(chan []*ads.DeltaDiscoveryResponse) 297 | go func() { 298 | chunkedResponses := make(map[string][]*ads.DeltaDiscoveryResponse) 299 | 300 | for { 301 | res, err := stream.Recv() 302 | if err != nil { 303 | slog.WarnContext(stream.Context(), "ADS stream closed", "err", err) 304 | return 305 | } 306 | 307 | slog.Debug("Response received", "res", res) 308 | 309 | var resSlice []*ads.DeltaDiscoveryResponse 310 | 311 | if c.responseChunkingSupported { 312 | resSlice = chunkedResponses[res.TypeUrl] 313 | resSlice = append(resSlice, res) 314 | chunkedResponses[res.TypeUrl] = resSlice 315 | if remainingChunks, _ := ads.ParseRemainingChunksFromNonce(res.Nonce); remainingChunks != 0 { 316 | continue 317 | } else { 318 | delete(chunkedResponses, res.TypeUrl) 319 | } 320 | } else { 321 | resSlice = []*ads.DeltaDiscoveryResponse{res} 322 | } 323 | 324 | select { 325 | case responses <- resSlice: 326 | case <-stream.Context().Done(): 327 | slog.WarnContext(stream.Context(), "ADS stream closed", "err", stream.Context().Err()) 328 | return 329 | } 330 | } 331 | }() 332 | 333 | return stream, responses, nil 334 | } 335 | 336 | type deltaClient interface { 337 | Send(*ads.DeltaDiscoveryRequest) error 338 | Recv() (*ads.DeltaDiscoveryResponse, error) 339 | Context() context.Context 340 | } 341 | 342 | // getDeltaClient attempts to reconnect to the ADS Server until it either successfully establishes a 343 | // stream, or the underlying gRPC connection is explicitly closed, signaling a shutdown. 344 | func (c *ADSClient) getDeltaClient() (deltaClient, error) { 345 | backoff := c.initialReconnectBackoff 346 | for { 347 | delta, err := discoveryv3.NewAggregatedDiscoveryServiceClient(c.conn). 
348 | DeltaAggregatedResources(context.Background()) 349 | if err != nil { 350 | // This only occurs if c.conn was closed since context.Background() is used to create the stream. 351 | if st := status.Convert(err); st.Code() == codes.Canceled { 352 | return nil, err 353 | } 354 | 355 | slog.Warn("Failed to create Delta stream, retrying", "backoff", backoff, "err", err) 356 | time.Sleep(backoff) 357 | backoff = min(backoff*2, c.maxReconnectBackoff) 358 | continue 359 | } 360 | return delta, nil 361 | } 362 | } 363 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package diderot 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "iter" 8 | "maps" 9 | "sync" 10 | "sync/atomic" 11 | "testing" 12 | "time" 13 | 14 | discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" 15 | "github.com/linkedin/diderot/ads" 16 | "github.com/linkedin/diderot/internal/utils" 17 | "github.com/linkedin/diderot/testutils" 18 | "github.com/stretchr/testify/require" 19 | "google.golang.org/grpc" 20 | "google.golang.org/grpc/codes" 21 | "google.golang.org/grpc/metadata" 22 | "google.golang.org/grpc/status" 23 | "google.golang.org/protobuf/proto" 24 | "google.golang.org/protobuf/types/known/durationpb" 25 | "google.golang.org/protobuf/types/known/timestamppb" 26 | ) 27 | 28 | type Timestamp = timestamppb.Timestamp 29 | 30 | var Now = timestamppb.Now 31 | 32 | func newClient(t *testing.T, chunkingSupport bool) (conn *mockConn, client *ADSClient) { 33 | conn = &mockConn{ 34 | t: t, 35 | streams: make(chan *mockStream), 36 | } 37 | client = NewADSClient( 38 | conn, 39 | &ads.Node{Id: "test"}, 40 | WithReconnectBackoff(0, 0), 41 | WithResponseChunkingSupported(chunkingSupport), 42 | ) 43 | return conn, client 44 | } 45 | 46 | func TestADSClient(t *testing.T) { 47 | tests := []struct { 48 | chunkingEnabled bool 49 | }{ 50 | { 51 | 
chunkingEnabled: false, 52 | }, 53 | { 54 | chunkingEnabled: true, 55 | }, 56 | } 57 | for _, test := range tests { 58 | t.Run(fmt.Sprintf("flow/chunkingEnabled=%v", test.chunkingEnabled), func(t *testing.T) { 59 | testADSClientFlow(t, test.chunkingEnabled) 60 | }) 61 | } 62 | 63 | // Check that the client NACKs a response for a type that was never subscribed to. 64 | t.Run("invalid type", func(t *testing.T) { 65 | conn, client := newClient(t, defaultResponseChunkingSupported) 66 | 67 | Watch(client, ads.WildcardSubscription, &FuncWatcher[*Timestamp]{ 68 | notify: func(resources iter.Seq2[string, *ads.Resource[*Timestamp]]) error { 69 | require.FailNow(t, "Should not be called") 70 | return nil 71 | }, 72 | }) 73 | 74 | ms := conn.accept() 75 | 76 | ms.expectSubscriptions(ads.WildcardSubscription) 77 | 78 | nonce := respond[*durationpb.Duration](ms, []*ads.Resource[*durationpb.Duration]{ 79 | ads.NewResource[*durationpb.Duration]("test", "0", durationpb.New(time.Minute)), 80 | }, nil, 0) 81 | 82 | expectNACK[*durationpb.Duration](ms, nonce, codes.InvalidArgument, utils.GetTypeURL[*durationpb.Duration]()) 83 | }) 84 | 85 | // Check that the client NACKs a response if a watcher returns an error. 86 | t.Run("NACKs", func(t *testing.T) { 87 | conn, client := newClient(t, defaultResponseChunkingSupported) 88 | 89 | Watch(client, ads.WildcardSubscription, &FuncWatcher[*Timestamp]{ 90 | notify: func(resources iter.Seq2[string, *ads.Resource[*Timestamp]]) error { 91 | return io.EOF 92 | }, 93 | }) 94 | 95 | ms := conn.accept() 96 | 97 | ms.expectSubscriptions(ads.WildcardSubscription) 98 | 99 | nonce := ms.respondUpdates(0, ads.NewResource[*Timestamp]("foo", "0", Now())) 100 | 101 | expectNACK[*Timestamp](ms, nonce, codes.InvalidArgument, io.EOF.Error()) 102 | }) 103 | 104 | // Check that if the server responds with an unknown resource, it is skipped and reported, but other 105 | // valid resources in the response are still parsed. 
106 | t.Run("unknown resource", func(t *testing.T) { 107 | conn, client := newClient(t, defaultResponseChunkingSupported) 108 | 109 | fooH := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) 110 | foo := ads.NewResource[*Timestamp]("foo", "0", Now()) 111 | Watch(client, foo.Name, ChanWatcher[*Timestamp](fooH)) 112 | 113 | ms := conn.accept() 114 | 115 | ms.expectSubscriptions(foo.Name) 116 | 117 | nonce := ms.respondUpdates(0, foo, ads.NewResource("bar", "0", Now())) 118 | 119 | expectNACK[*Timestamp](ms, nonce, codes.InvalidArgument, "bar") 120 | fooH.WaitForUpdate(t, foo) 121 | }) 122 | } 123 | 124 | func testADSClientFlow(t *testing.T, chunkingEnabled bool) { 125 | conn, client := newClient(t, chunkingEnabled) 126 | 127 | fooH := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) 128 | foo := ads.NewResource[*Timestamp]("foo", "0", Now()) 129 | Watch(client, foo.Name, ChanWatcher[*Timestamp](fooH)) 130 | 131 | // The stream has not yet been established, no updates should be received. 132 | checkNoUpdate(t, fooH) 133 | 134 | // Accept a new stream 135 | ms := conn.accept() 136 | 137 | // The resource does not initially exist, the first update should be a deletion. 138 | ms.expectSubscriptions(foo.Name) 139 | nonce := ms.respondDeletes(0, foo.Name) 140 | fooH.WaitForDelete(t, foo.Name) 141 | ms.expectACK(nonce) 142 | 143 | // Set foo, and wait for the creation update 144 | nonce = ms.respondUpdates(0, foo) 145 | fooH.WaitForUpdate(t, foo) 146 | ms.expectACK(nonce) 147 | 148 | ms.cancel() 149 | ms = conn.accept() 150 | // Closing and reopening the stream makes the client reconnect, but since foo hasn't changed, nothing 151 | // should happen. 
152 | ms.expectSubscriptions(foo.Name) 153 | nonce = ms.respondUpdates(0, foo) 154 | checkNoUpdate(t, fooH) 155 | ms.expectACK(nonce) 156 | 157 | // Disconnect the client, foo is updated during disconnect so expect a notification 158 | ms.cancel() 159 | foo = ads.NewResource(foo.Name, "1", Now()) 160 | ms = conn.accept() 161 | ms.expectSubscriptions(foo.Name) 162 | nonce = ms.respondUpdates(0, foo) 163 | fooH.WaitForUpdate(t, foo) 164 | ms.expectACK(nonce) 165 | 166 | wildcardH := make(testutils.ChanSubscriptionHandler[*Timestamp], 2) 167 | var wildcardExpectedCount atomic.Int32 168 | Watch(client, ads.WildcardSubscription, &FuncWatcher[*Timestamp]{ 169 | notify: func(resources iter.Seq2[string, *ads.Resource[*Timestamp]]) error { 170 | require.Len(t, maps.Collect(resources), int(wildcardExpectedCount.Load())) 171 | for name, resource := range resources { 172 | wildcardH <- testutils.Notification[*Timestamp]{ 173 | Name: name, 174 | Resource: resource, 175 | } 176 | } 177 | return nil 178 | }, 179 | }) 180 | 181 | ms.expectSubscriptions(ads.WildcardSubscription) 182 | bar := ads.NewResource[*Timestamp]("bar", "0", Now()) 183 | if chunkingEnabled { 184 | // Respond in multiple chunks, to test that those are handled correctly 185 | chunkNonce1 := ms.respondUpdates(1, foo) 186 | // No update expected after first chunk 187 | checkNoUpdate(t, wildcardH) 188 | // As soon as the second chunk arrives, an update is expected, so update the expected count before 189 | // sending the response. 190 | wildcardExpectedCount.Store(2) 191 | chunkNonce2 := ms.respondUpdates(0, bar) 192 | ms.expectACK(chunkNonce1) 193 | ms.expectACK(chunkNonce2) 194 | } else { 195 | wildcardExpectedCount.Store(2) 196 | nonce = ms.respondUpdates(0, foo, bar) 197 | ms.expectACK(nonce) 198 | } 199 | 200 | // Expect a notification for foo and bar for wildcardH, but since fooH has already seen that version 201 | // of foo, it should not receive an update. 
202 | wildcardH.WaitForNotifications(t, 203 | testutils.ExpectUpdate(foo), 204 | testutils.ExpectUpdate(bar), 205 | ) 206 | checkNoUpdate(t, fooH) 207 | 208 | // Clear foo, expect a deletion on fooH and the wildcard subscriber. 209 | wildcardExpectedCount.Store(1) 210 | nonce = ms.respondDeletes(0, foo.Name) 211 | ms.expectACK(nonce) 212 | fooH.WaitForDelete(t, foo.Name) 213 | wildcardH.WaitForDelete(t, foo.Name) 214 | 215 | // Create new glob collection entries, which the wildcard subscriber should receive. 216 | wildcardExpectedCount.Store(1) 217 | gcURL := ads.NewGlobCollectionURL[*Timestamp]("", "collection", nil) 218 | fooGlob := ads.NewResource(gcURL.MemberURN("foo"), "0", Now()) 219 | nonce = ms.respondUpdates(0, fooGlob) 220 | ms.expectACK(nonce) 221 | wildcardH.WaitForNotifications(t, testutils.ExpectUpdate(fooGlob)) 222 | 223 | barGlob := ads.NewResource(gcURL.MemberURN("bar"), "0", Now()) 224 | nonce = ms.respondUpdates(0, barGlob) 225 | ms.expectACK(nonce) 226 | wildcardH.WaitForNotifications(t, testutils.ExpectUpdate(barGlob)) 227 | 228 | // Subscribe to the glob collection. expecting an update for fooGlob and barGlob. 229 | globH := make(testutils.ChanSubscriptionHandler[*Timestamp], 2) 230 | var globExpectedCount atomic.Int32 231 | // Because the resources are already known thanks to the wildcard, this expects a notification 232 | // immediately, before the subscription is even sent. 
233 | Watch(client, gcURL.String(), &FuncWatcher[*Timestamp]{ 234 | notify: func(resources iter.Seq2[string, *ads.Resource[*Timestamp]]) error { 235 | require.Len(t, maps.Collect(resources), int(globExpectedCount.Load())) 236 | for name, resource := range resources { 237 | globH <- testutils.Notification[*Timestamp]{ 238 | Name: name, 239 | Resource: resource, 240 | } 241 | } 242 | return nil 243 | }, 244 | }) 245 | ms.expectSubscriptions(gcURL.String()) 246 | globExpectedCount.Store(2) 247 | nonce = ms.respondUpdates(0, fooGlob, barGlob) 248 | ms.expectACK(nonce) 249 | globH.WaitForNotifications(t, 250 | testutils.ExpectUpdate(fooGlob), 251 | testutils.ExpectUpdate(barGlob), 252 | ) 253 | globExpectedCount.Store(0) 254 | 255 | // Clear fooGlob, expect deletions for it. 256 | wildcardExpectedCount.Store(1) 257 | globExpectedCount.Store(1) 258 | nonce = ms.respondDeletes(0, fooGlob.Name) 259 | ms.expectACK(nonce) 260 | wildcardH.WaitForDelete(t, fooGlob.Name) 261 | globH.WaitForDelete(t, fooGlob.Name) 262 | 263 | // Disconnect the client and clear the collection during the disconnect. When the client reconnects, 264 | // because it explicitly subscribes to the glob collection it will receive a deletion notification 265 | // for the entire collection, but not for barGlob explicitly, as the server has forgotten that it 266 | // exists. The client must figure out that barGlob has disappeared while it was disconnected. The 267 | // same is true for the wildcard subscription: the client will not receive an explicit notification 268 | // that barGlob has disappeared. 
269 | ms.cancel() 270 | ms = conn.accept() 271 | ms.expectSubscriptions(foo.Name, ads.WildcardSubscription, gcURL.String()) 272 | 273 | nonce = respond[*Timestamp]( 274 | ms, 275 | // The only remaining resource is bar 276 | []*ads.Resource[*Timestamp]{bar}, 277 | // These are explicitly subscribed to but do not exist, so explicit removals are expected 278 | []string{foo.Name, gcURL.String()}, 279 | 0, 280 | ) 281 | ms.respondUpdates(0, bar) 282 | ms.expectACK(nonce) 283 | globH.WaitForDelete(t, barGlob.Name) 284 | wildcardH.WaitForDelete(t, barGlob.Name) 285 | 286 | // This is an edge case, but bar is known because of the wildcard subscription. Therefore, even while 287 | // the client is offline, subscribing to bar should deliver the notification. 288 | ms.cancel() 289 | barH := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) 290 | Watch(client, bar.Name, ChanWatcher[*Timestamp](barH)) 291 | barH.WaitForUpdate(t, bar) 292 | ms = conn.accept() 293 | // There should be an explicit subscription sent, but because bar is already known, no further 294 | // updates should be received. 295 | ms.expectSubscriptions(foo.Name, bar.Name, ads.WildcardSubscription, gcURL.String()) 296 | nonce = ms.respondUpdates(0, bar) 297 | ms.expectACK(nonce) 298 | checkNoUpdate(t, barH) 299 | 300 | // Delete bar, the final resource 301 | nonce = ms.respondDeletes(0, bar.Name) 302 | ms.expectACK(nonce) 303 | 304 | barH.WaitForDelete(t, bar.Name) 305 | wildcardH.WaitForDelete(t, bar.Name) 306 | 307 | // Disconnect again to test what happens when Watch is called while offline for glob and wildcards. 308 | ms.cancel() 309 | allResources := new(map[string]*ads.Resource[*Timestamp]) 310 | Watch(client, ads.WildcardSubscription, OnceWatcher(allResources)) 311 | // This should be immediately ready, as data has been received and far as the client knows, there are 312 | // no resources. 
313 | require.NotNil(t, *allResources) 314 | require.Empty(t, *allResources) 315 | 316 | // Same behavior expected for glob 317 | allGlobResource := new(map[string]*ads.Resource[*Timestamp]) 318 | Watch(client, gcURL.String(), OnceWatcher(allGlobResource)) 319 | require.NotNil(t, *allGlobResource) 320 | require.Empty(t, *allGlobResource) 321 | } 322 | 323 | type FuncWatcher[T proto.Message] struct { 324 | notify func(resources iter.Seq2[string, *ads.Resource[T]]) error 325 | } 326 | 327 | func (f FuncWatcher[T]) Notify(resources iter.Seq2[string, *ads.Resource[T]]) error { 328 | return f.notify(resources) 329 | } 330 | 331 | type ChanWatcher[T proto.Message] testutils.ChanSubscriptionHandler[T] 332 | 333 | func (c ChanWatcher[T]) Notify(resources iter.Seq2[string, *ads.Resource[T]]) error { 334 | for name, resource := range resources { 335 | testutils.ChanSubscriptionHandler[T](c).Notify(name, resource, ads.SubscriptionMetadata{}) 336 | } 337 | return nil 338 | } 339 | 340 | func checkNoUpdate[T proto.Message](t *testing.T, h testutils.ChanSubscriptionHandler[T]) { 341 | select { 342 | case n := <-h: 343 | require.FailNow(t, "handler should not receive any messages", n) 344 | case <-time.After(500 * time.Millisecond): 345 | } 346 | } 347 | 348 | func OnceWatcher[T proto.Message](m *map[string]*ads.Resource[T]) Watcher[T] { 349 | var once sync.Once 350 | return &FuncWatcher[T]{notify: func(resources iter.Seq2[string, *ads.Resource[T]]) error { 351 | once.Do(func() { 352 | *m = maps.Collect(resources) 353 | }) 354 | return nil 355 | }} 356 | } 357 | 358 | type mockConn struct { 359 | t *testing.T 360 | streams chan *mockStream 361 | } 362 | 363 | func (mc *mockConn) Invoke(context.Context, string, any, any, ...grpc.CallOption) error { 364 | mc.t.Fatalf("Not supported") 365 | return nil 366 | } 367 | 368 | func (mc *mockConn) NewStream( 369 | _ context.Context, 370 | _ *grpc.StreamDesc, 371 | method string, 372 | _ ...grpc.CallOption, 373 | ) (grpc.ClientStream, 
error) { 374 | require.Equal(mc.t, discovery.AggregatedDiscoveryService_DeltaAggregatedResources_FullMethodName, method) 375 | return <-mc.streams, nil 376 | } 377 | 378 | func (mc *mockConn) accept() *mockStream { 379 | s := &mockStream{ 380 | conn: mc, 381 | requests: make(chan *ads.DeltaDiscoveryRequest), 382 | responses: make(chan *ads.DeltaDiscoveryResponse), 383 | } 384 | s.ctx, s.cancel = context.WithCancel(context.Background()) 385 | mc.streams <- s 386 | return s 387 | } 388 | 389 | type mockStream struct { 390 | conn *mockConn 391 | ctx context.Context 392 | cancel context.CancelFunc 393 | requests chan *ads.DeltaDiscoveryRequest 394 | responses chan *ads.DeltaDiscoveryResponse 395 | } 396 | 397 | func (ms *mockStream) Header() (metadata.MD, error) { 398 | return nil, nil 399 | } 400 | 401 | func (ms *mockStream) Trailer() metadata.MD { 402 | return nil 403 | } 404 | 405 | func (ms *mockStream) CloseSend() error { 406 | ms.conn.t.Fatalf("Not supported") 407 | return nil 408 | } 409 | 410 | func (ms *mockStream) Context() context.Context { 411 | return ms.ctx 412 | } 413 | 414 | func (ms *mockStream) SendMsg(msg any) error { 415 | require.IsType(ms.conn.t, (*ads.DeltaDiscoveryRequest)(nil), msg) 416 | select { 417 | case ms.requests <- msg.(*ads.DeltaDiscoveryRequest): 418 | return nil 419 | case <-ms.ctx.Done(): 420 | return ms.ctx.Err() 421 | } 422 | } 423 | 424 | func (ms *mockStream) RecvMsg(msg any) error { 425 | require.IsType(ms.conn.t, (*ads.DeltaDiscoveryResponse)(nil), msg) 426 | select { 427 | case res := <-ms.responses: 428 | proto.Merge(msg.(*ads.DeltaDiscoveryResponse), res) 429 | return nil 430 | case <-ms.ctx.Done(): 431 | return ms.ctx.Err() 432 | } 433 | } 434 | 435 | func (ms *mockStream) respondUpdates( 436 | remainingChunks int, 437 | resources ...*ads.Resource[*Timestamp], 438 | ) string { 439 | return respond[*Timestamp](ms, resources, nil, remainingChunks) 440 | } 441 | 442 | func (ms *mockStream) respondDeletes( 443 | 
remainingChunks int, 444 | removedResources ...string, 445 | ) string { 446 | return respond[*Timestamp](ms, nil, removedResources, remainingChunks) 447 | } 448 | 449 | func respond[T proto.Message]( 450 | ms *mockStream, 451 | resources []*ads.Resource[T], 452 | removedResources []string, 453 | remainingChunks int, 454 | ) string { 455 | var marshaled []*ads.RawResource 456 | for _, resource := range resources { 457 | raw, err := resource.Marshal() 458 | require.NoError(ms.conn.t, err) 459 | marshaled = append(marshaled, raw) 460 | } 461 | nonce := utils.NewNonce(remainingChunks) 462 | ms.responses <- &ads.DeltaDiscoveryResponse{ 463 | Resources: marshaled, 464 | TypeUrl: utils.GetTypeURL[T](), 465 | RemovedResources: removedResources, 466 | Nonce: nonce, 467 | } 468 | return nonce 469 | } 470 | 471 | func (ms *mockStream) expectACK(nonce string) { 472 | req := <-ms.requests 473 | require.Equal(ms.conn.t, utils.GetTypeURL[*Timestamp](), req.TypeUrl) 474 | require.Equal(ms.conn.t, nonce, req.ResponseNonce) 475 | } 476 | 477 | func expectNACK[T proto.Message](ms *mockStream, nonce string, code codes.Code, errorContains string) { 478 | req := <-ms.requests 479 | require.Equal(ms.conn.t, utils.GetTypeURL[T](), req.TypeUrl) 480 | require.Equal(ms.conn.t, nonce, req.ResponseNonce) 481 | st := status.FromProto(req.GetErrorDetail()) 482 | require.Equal(ms.conn.t, code, st.Code()) 483 | require.ErrorContains(ms.conn.t, st.Err(), errorContains) 484 | } 485 | 486 | func (ms *mockStream) expectSubscriptions(subscriptions ...string) { 487 | req := <-ms.requests 488 | require.Equal(ms.conn.t, utils.GetTypeURL[*Timestamp](), req.TypeUrl) 489 | require.Empty(ms.conn.t, req.ResponseNonce) 490 | require.ElementsMatch(ms.conn.t, subscriptions, req.ResourceNamesSubscribe) 491 | } 492 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package diderot 
provides a set of utilities to implement an xDS control plane in Go. Namely, it 3 | provides two core elements: 4 | 1. The [ADSServer], the implementation of both the SotW and Delta ADS stream variants. 5 | 2. The [Cache], which is an efficient means to store, retrieve and subscribe to xDS resource definitions. 6 | 7 | # ADS Server and Resource Locator 8 | 9 | The [ADSServer] is an implementation of the xDS protocol's various features. It implements both the 10 | Delta and state-of-the-world variants, but abstracts this away completely by only exposing a single 11 | entry point: the [ResourceLocator]. When the server receives a request (be it Delta or SotW), it 12 | will check whether it is an ACK (or a NACK), then invoke the corresponding subscription methods on 13 | the [ResourceLocator]. The locator is simply in charge of invoking Notify on the handler whenever 14 | the resource changes, and the server will relay that resource update to the client using the 15 | corresponding response type. This makes it very easy to implement an xDS control plane without 16 | needing to worry about the finer details of the xDS protocol. 17 | 18 | Most ResourceLocator implementations will likely be a series of [Cache] instances for the 19 | corresponding supported types, which implements the semantics of Subscribe and Resubscribe out of 20 | the box. However, as long as the semantics are respected, implementations may do as they please. For 21 | example, a common pattern is listed in the [xDS spec]: 22 | 23 | For Listener and Cluster resource types, there is also a “wildcard” subscription, which is triggered 24 | when subscribing to the special name *. In this case, the server should use site-specific business 25 | logic to determine the full set of resources that the client is interested in, typically based on 26 | the client’s node identification.
27 | 28 | Instead of subscribing to a backing [Cache] with the wildcard subscription, the said 29 | "business logic" can be implemented in the [ResourceLocator] and wildcard subscriptions can be 30 | transformed into an explicit set of resources. 31 | 32 | # Cache 33 | 34 | This type is the core building block provided by this package. It is effectively a map from 35 | resource name to [ads.Resource] definitions. It provides a way to subscribe to them in order to be 36 | notified whenever they change. For example, the [ads.Endpoint] type (aka 37 | "envoy.config.endpoint.v3.ClusterLoadAssignment") contains the set of IPs that back a specific 38 | [ads.Cluster] ("envoy.config.cluster.v3.Cluster") and is the final step in the standard LDS -> RDS 39 | -> CDS -> EDS Envoy flow. The Cache will store the Endpoint instances that back each cluster, and 40 | Envoy will be able to subscribe to the [ads.Endpoint] resource by providing the correct name when 41 | subscribing. See [diderot.Cache.Subscribe] for additional details on the subscription model. 42 | 43 | It is safe for concurrent use as its concurrency model is per-resource. This means different 44 | goroutines can modify different resources concurrently, and goroutines attempting to modify the 45 | same resource will be synchronized. 46 | 47 | # Cache Priority 48 | 49 | The cache supports a notion of "priority". Concretely, this feature is intended to be used when a 50 | resource definition can come from multiple sources. For example, if resource definitions are being 51 | migrated from one source to another, it would be sane to always use the new source if it is present, 52 | otherwise fall back to the old source. This would be as opposed to simply picking whichever source 53 | defined the resource most recently, as it would mean the resource definition cannot be relied upon 54 | to be stable. [NewPrioritizedCache] returns a slice of instances of their respective types.
The 55 | instances all point to the same underlying cache, but at different priorities, where instances that 56 | appear earlier in the slice have a higher priority than those that appear later. If a resource is 57 | defined at priorities p1 and p2 where p1 is a higher priority than p2, subscribers will see the 58 | version that was defined at p1. If the resource is cleared at p1, the cache will fall back to the 59 | definition at p2. This means that a resource is only ever considered fully deleted if it is cleared 60 | at all priority levels. The reason a slice of instances is returned rather than adding a priority 61 | parameter to each function on [Cache] is to avoid complicated configuration or simple bugs where a 62 | resource is being set at an unintended or invalid priority. Instead, the code path where a source is 63 | populating the cache simply receives a reference to the cache and starts writing to it. If the 64 | priority of a source changes in subsequent versions, it can be handled at initialization/startup 65 | instead of requiring any actual code changes to the source itself. 66 | 67 | # xDS TP1 Support 68 | 69 | The notion of glob collections defined in the TP1 proposal is supported natively in the [Cache]. 70 | This means that if resource names are [xdstp:// URNs], they will be automatically added to the 71 | corresponding glob collection, if applicable. These resources are still available for subscription 72 | by their full URN, but will also be available for subscription by subscribing to the parent glob 73 | collection. More details available at [diderot.Cache.Subscribe], [ads.ParseGlobCollectionURL] and 74 | [ads.ParseGlobCollectionURN]. 
75 | 76 | [xDS spec]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol#how-the-client-specifies-what-resources-to-return 77 | [xdstp:// URNs]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names 78 | */ 79 | package diderot 80 | -------------------------------------------------------------------------------- /examples/quickstart/main.go: -------------------------------------------------------------------------------- 1 | //go:build examples 2 | 3 | package main 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "log" 9 | "net" 10 | "os" 11 | 12 | corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" 13 | discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" 14 | "github.com/linkedin/diderot" 15 | "github.com/linkedin/diderot/ads" 16 | "google.golang.org/grpc" 17 | "google.golang.org/protobuf/proto" 18 | ) 19 | 20 | func main() { 21 | lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", 8080)) 22 | if err != nil { 23 | log.Fatalf("failed to listen: %v", err) 24 | } 25 | 26 | var opts []grpc.ServerOption 27 | // populate your ops 28 | grpcServer := grpc.NewServer(opts...) 29 | 30 | // Use a very simple ResourceLocator that only supports a limited set of types (namely LDS -> RDS -> CDS -> EDS). 31 | locator := NewSimpleResourceLocator(ListenerType, RouteType, ClusterType, EndpointType) 32 | 33 | go PopulateCaches(locator) 34 | 35 | hostname, _ := os.Hostname() 36 | 37 | adsServer := diderot.NewADSServer(locator, 38 | // Send max 10k responses per second. 39 | diderot.WithGlobalResponseRateLimit(10_000), 40 | // Send max one response per type per client every 500ms, to not overload clients. 41 | diderot.WithGranularResponseRateLimit(2), 42 | // Process max 1k requests per second. 
43 | diderot.WithRequestRateLimit(1000), 44 | diderot.WithControlPlane(&corev3.ControlPlane{Identifier: hostname}), 45 | ) 46 | discovery.RegisterAggregatedDiscoveryServiceServer(grpcServer, adsServer) 47 | 48 | grpcServer.Serve(lis) 49 | } 50 | 51 | var ( 52 | ListenerType = diderot.TypeOf[*ads.Listener]() 53 | RouteType = diderot.TypeOf[*ads.Route]() 54 | ClusterType = diderot.TypeOf[*ads.Cluster]() 55 | EndpointType = diderot.TypeOf[*ads.Endpoint]() 56 | ) 57 | 58 | // SimpleResourceLocator is a bare-bones [diderot.ResourceLocator] that provides the bare minimum 59 | // functionality. 60 | type SimpleResourceLocator map[string]diderot.RawCache 61 | 62 | func (sl SimpleResourceLocator) Subscribe( 63 | _ context.Context, 64 | typeURL, resourceName string, 65 | handler ads.RawSubscriptionHandler, 66 | ) (unsubscribe func()) { 67 | c, ok := sl[typeURL] 68 | if !ok { 69 | // Do nothing if the given type is not supported 70 | return func() {} 71 | } 72 | diderot.Subscribe(c, resourceName, handler) 73 | return func() { 74 | diderot.Unsubscribe(c, resourceName, handler) 75 | } 76 | } 77 | 78 | // getCache extracts a typed [diderot.Cache] from the given [SimpleResourceLocator]. 
79 | func getCache[T proto.Message](sl SimpleResourceLocator) diderot.Cache[T] { 80 | return sl[diderot.TypeOf[T]().URL()].(diderot.Cache[T]) 81 | } 82 | 83 | func (sl SimpleResourceLocator) GetListenerCache() diderot.Cache[*ads.Listener] { 84 | return getCache[*ads.Listener](sl) 85 | } 86 | 87 | func (sl SimpleResourceLocator) GetRouteCache() diderot.Cache[*ads.Route] { 88 | return getCache[*ads.Route](sl) 89 | } 90 | 91 | func (sl SimpleResourceLocator) GetClusterCache() diderot.Cache[*ads.Cluster] { 92 | return getCache[*ads.Cluster](sl) 93 | } 94 | 95 | func (sl SimpleResourceLocator) GetEndpointCache() diderot.Cache[*ads.Endpoint] { 96 | return getCache[*ads.Endpoint](sl) 97 | } 98 | 99 | func NewSimpleResourceLocator(types ...diderot.Type) SimpleResourceLocator { 100 | sl := make(SimpleResourceLocator) 101 | for _, t := range types { 102 | sl[t.URL()] = t.NewCache() 103 | } 104 | return sl 105 | } 106 | 107 | func PopulateCaches(locator SimpleResourceLocator) { 108 | // this is where the business logic of populating the caches should happen. For example, you can read 109 | // the resource definitions from disk, listen to ZK, etc... 
110 | } 111 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/linkedin/diderot 2 | 3 | go 1.24.1 4 | 5 | require ( 6 | github.com/envoyproxy/go-control-plane v0.13.4 7 | github.com/envoyproxy/go-control-plane/envoy v1.32.4 8 | github.com/google/go-cmp v0.6.0 9 | github.com/puzpuzpuz/xsync/v4 v4.0.0 10 | github.com/stretchr/testify v1.10.0 11 | golang.org/x/time v0.5.0 12 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a 13 | google.golang.org/grpc v1.70.0 14 | google.golang.org/protobuf v1.36.4 15 | ) 16 | 17 | require ( 18 | cel.dev/expr v0.19.0 // indirect 19 | cloud.google.com/go/compute/metadata v0.5.2 // indirect 20 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 21 | github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect 22 | github.com/davecgh/go-spew v1.1.1 // indirect 23 | github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect 24 | github.com/kr/text v0.2.0 // indirect 25 | github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect 26 | github.com/pmezard/go-difflib v1.0.0 // indirect 27 | golang.org/x/net v0.34.0 // indirect 28 | golang.org/x/oauth2 v0.24.0 // indirect 29 | golang.org/x/sync v0.10.0 // indirect 30 | golang.org/x/sys v0.29.0 // indirect 31 | golang.org/x/text v0.21.0 // indirect 32 | google.golang.org/genproto/googleapis/api v0.0.0-20241202173237-19429a94021a // indirect 33 | gopkg.in/yaml.v3 v3.0.1 // indirect 34 | ) 35 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | cel.dev/expr v0.19.0 h1:lXuo+nDhpyJSpWxpPVi5cPUwzKb+dsdOiw6IreM5yt0= 2 | cel.dev/expr v0.19.0/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= 3 | cloud.google.com/go/compute/metadata v0.5.2 
h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo= 4 | cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k= 5 | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 6 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 7 | github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 h1:QVw89YDxXxEe+l8gU8ETbOasdwEV+avkR75ZzsVV9WI= 8 | github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= 9 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 10 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 11 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 12 | github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M= 13 | github.com/envoyproxy/go-control-plane v0.13.4/go.mod h1:kDfuBlDVsSj2MjrLEtRWtHlsWIFcGyB2RMO44Dc5GZA= 14 | github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8kXZ5CQAFYVjQcdVIr83A= 15 | github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw= 16 | github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI= 17 | github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= 18 | github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= 19 | github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= 20 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= 21 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 22 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 23 | github.com/go-logr/stdr v1.2.2/go.mod 
h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 24 | github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= 25 | github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= 26 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 27 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 28 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 29 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 30 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 31 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 32 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 33 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 34 | github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= 35 | github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= 36 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 37 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 38 | github.com/puzpuzpuz/xsync/v4 v4.0.0 h1:F1za+MBXzDQtQq+OVgFsojSX4w66rsNDmQNebPFAncA= 39 | github.com/puzpuzpuz/xsync/v4 v4.0.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo= 40 | github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= 41 | github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= 42 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 43 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 44 | go.opentelemetry.io/otel v1.32.0 
h1:WnBN+Xjcteh0zdk01SVqV55d/m62NJLJdIyb4y/WO5U= 45 | go.opentelemetry.io/otel v1.32.0/go.mod h1:00DCVSB0RQcnzlwyTfqtxSm+DRr9hpYrHjNGiBHVQIg= 46 | go.opentelemetry.io/otel/metric v1.32.0 h1:xV2umtmNcThh2/a/aCP+h64Xx5wsj8qqnkYZktzNa0M= 47 | go.opentelemetry.io/otel/metric v1.32.0/go.mod h1:jH7CIbbK6SH2V2wE16W05BHCtIDzauciCRLoc/SyMv8= 48 | go.opentelemetry.io/otel/sdk v1.32.0 h1:RNxepc9vK59A8XsgZQouW8ue8Gkb4jpWtJm9ge5lEG4= 49 | go.opentelemetry.io/otel/sdk v1.32.0/go.mod h1:LqgegDBjKMmb2GC6/PrTnteJG39I8/vJCAP9LlJXEjU= 50 | go.opentelemetry.io/otel/sdk/metric v1.32.0 h1:rZvFnvmvawYb0alrYkjraqJq0Z4ZUJAiyYCU9snn1CU= 51 | go.opentelemetry.io/otel/sdk/metric v1.32.0/go.mod h1:PWeZlq0zt9YkYAp3gjKZ0eicRYvOh1Gd+X99x6GHpCQ= 52 | go.opentelemetry.io/otel/trace v1.32.0 h1:WIC9mYrXf8TmY/EXuULKc8hR17vE+Hjv2cssQDe03fM= 53 | go.opentelemetry.io/otel/trace v1.32.0/go.mod h1:+i4rkvCraA+tG6AzwloGaCtkx53Fa+L+V8e9a7YvhT8= 54 | golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= 55 | golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= 56 | golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= 57 | golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= 58 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 59 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 60 | golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= 61 | golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 62 | golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= 63 | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= 64 | golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= 65 | golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 66 | google.golang.org/genproto/googleapis/api v0.0.0-20241202173237-19429a94021a 
h1:OAiGFfOiA0v9MRYsSidp3ubZaBnteRUyn3xB2ZQ5G/E= 67 | google.golang.org/genproto/googleapis/api v0.0.0-20241202173237-19429a94021a/go.mod h1:jehYqy3+AhJU9ve55aNOaSml7wUXjF9x6z2LcCfpAhY= 68 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a h1:hgh8P4EuoxpsuKMXX/To36nOFD7vixReXgn8lPGnt+o= 69 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= 70 | google.golang.org/grpc v1.70.0 h1:pWFv03aZoHzlRKHWicjsZytKAiYCtNS0dHbXnIdq7jQ= 71 | google.golang.org/grpc v1.70.0/go.mod h1:ofIJqVKDXx/JiXrwr2IG4/zwdH9txy3IlF40RmcJSQw= 72 | google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM= 73 | google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= 74 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 75 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 76 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 77 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 78 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 79 | -------------------------------------------------------------------------------- /internal/cache/glob_collection.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/linkedin/diderot/ads" 8 | "github.com/linkedin/diderot/internal/utils" 9 | "google.golang.org/protobuf/proto" 10 | ) 11 | 12 | // A globCollection is used to track all the resources in the collection. 
13 | type globCollection[T proto.Message] struct { 14 | // The URL that corresponds to this collection, represented as the raw string rather than a 15 | // GlobCollectionURL to avoid repeated redundant calls to GlobCollectionURL.String. 16 | url string 17 | 18 | // The current subscribers to this collection. 19 | subscribers SubscriberSet[T] 20 | // Protects values and nonNilValueNames. 21 | lock sync.RWMutex 22 | // The set of values in the collection, used by new subscribers to subscribe to all values. 23 | values utils.Set[*WatchableValue[T]] 24 | // The set of all non-nil resource names in this collection. Used to track whether a collection is 25 | // empty. Note that a collection can be empty even if values is non-empty since values that are 26 | // explicitly subscribed to are kept in the collection/cache to track the subscription in case the 27 | // value returns. 28 | nonNilValueNames utils.Set[string] 29 | } 30 | 31 | func newGlobCollection[T proto.Message](url string) *globCollection[T] { 32 | return &globCollection[T]{ 33 | url: url, 34 | values: make(utils.Set[*WatchableValue[T]]), 35 | nonNilValueNames: make(utils.Set[string]), 36 | } 37 | } 38 | 39 | func (g *globCollection[T]) hasNoValuesOrSubscribersNoLock() bool { 40 | return len(g.values) == 0 && g.subscribers.Size() == 0 41 | } 42 | 43 | // hasNoValuesOrSubscribers returns true if the collection is empty and has no subscribers. 44 | func (g *globCollection[T]) hasNoValuesOrSubscribers() bool { 45 | g.lock.RLock() 46 | defer g.lock.RUnlock() 47 | 48 | return g.hasNoValuesOrSubscribersNoLock() 49 | } 50 | 51 | // resourceSet notifies the collection that the given resource has been created. 52 | func (g *globCollection[T]) resourceSet(name string) { 53 | g.lock.Lock() 54 | defer g.lock.Unlock() 55 | 56 | g.nonNilValueNames.Add(name) 57 | } 58 | 59 | // resourceCleared notifies the collection that the given resource has been cleared. 
If there are no 60 | // remaining non-nil values in the collection (or no values at all), the subscribers are all notified 61 | // that the collection has been deleted. 62 | func (g *globCollection[T]) resourceCleared(name string) { 63 | g.lock.Lock() 64 | defer g.lock.Unlock() 65 | 66 | g.nonNilValueNames.Remove(name) 67 | 68 | if len(g.nonNilValueNames) > 0 { 69 | return 70 | } 71 | 72 | deletedAt := time.Now() 73 | 74 | for handler, subscribedAt := range g.subscribers.Iterator() { 75 | handler.Notify(g.url, nil, ads.SubscriptionMetadata{ 76 | SubscribedAt: subscribedAt, 77 | ModifiedAt: deletedAt, 78 | CachedAt: deletedAt, 79 | GlobCollectionURL: g.url, 80 | }) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /internal/cache/glob_collections_map.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "log/slog" 5 | "sync" 6 | "time" 7 | 8 | "github.com/linkedin/diderot/ads" 9 | "google.golang.org/protobuf/proto" 10 | ) 11 | 12 | // GlobCollectionsMap used to map individual GlobCollectionURL to their corresponding globCollection. 13 | // This uses a ResourceMap under the hood because it has similar semantics to cache entries: 14 | // 1. A globCollection is created lazily, either when an entry for that collection is created, or a 15 | // subscription to that collection is made. 16 | // 2. A globCollection is only deleted once all subscribers have unsubscribed and the collection is 17 | // empty. Crucially, a collection can be empty but will remain in the cache as long as some 18 | // subscribers remain subscribed. 
19 | type GlobCollectionsMap[T proto.Message] struct { 20 | collections *ResourceMap[ads.GlobCollectionURL, *globCollection[T]] 21 | } 22 | 23 | func NewGlobCollectionsMap[T proto.Message]() *GlobCollectionsMap[T] { 24 | return &GlobCollectionsMap[T]{ 25 | collections: NewResourceMap[ads.GlobCollectionURL, *globCollection[T]](), 26 | } 27 | } 28 | 29 | // createOrModifyCollection gets or creates the globCollection for the given GlobCollectionURL, and 30 | // executes the given function on it. 31 | func (gcm *GlobCollectionsMap[T]) createOrModifyCollection( 32 | gcURL ads.GlobCollectionURL, 33 | f func(collection *globCollection[T]), 34 | ) *globCollection[T] { 35 | gc, _ := gcm.collections.Compute( 36 | gcURL, 37 | func(gcURL ads.GlobCollectionURL) *globCollection[T] { 38 | gc := newGlobCollection[T](gcURL.String()) 39 | slog.Debug("Created collection", "url", gcURL) 40 | return gc 41 | }, 42 | f, 43 | ) 44 | return gc 45 | } 46 | 47 | // PutValueInCollection creates the glob collection if it was not already created, and puts the given 48 | // value in it. 49 | func (gcm *GlobCollectionsMap[T]) PutValueInCollection(gcURL ads.GlobCollectionURL, value *WatchableValue[T]) { 50 | gcm.createOrModifyCollection(gcURL, func(collection *globCollection[T]) { 51 | collection.lock.Lock() 52 | defer collection.lock.Unlock() 53 | 54 | value.globCollection = collection 55 | collection.values.Add(value) 56 | value.SubscriberSets[GlobSubscription] = &collection.subscribers 57 | }) 58 | } 59 | 60 | // RemoveValueFromCollection removes the given value from the collection. If the collection becomes 61 | // empty as a result, it is removed from the map. 
62 | func (gcm *GlobCollectionsMap[T]) RemoveValueFromCollection(gcURL ads.GlobCollectionURL, value *WatchableValue[T]) { 63 | gcm.deleteCollectionIfEmpty(gcURL, func(collection *globCollection[T]) { 64 | collection.lock.Lock() 65 | defer collection.lock.Unlock() 66 | 67 | collection.values.Remove(value) 68 | }) 69 | } 70 | 71 | // Subscribe creates or gets the corresponding collection for the given URL using 72 | // createOrModifyCollection. It adds the given handler as a subscriber to the collection, then 73 | // iterates through all the values in the collection, notifying the handler for each value. If the 74 | // collection is empty, the handler will be notified that the resource is deleted. See the 75 | // documentation on [WatchableValue.NotifyHandlerAfterSubscription] for more insight on the returned 76 | // [sync.WaitGroup] slice. 77 | func (gcm *GlobCollectionsMap[T]) Subscribe( 78 | gcURL ads.GlobCollectionURL, handler ads.SubscriptionHandler[T], 79 | ) (wgs []*sync.WaitGroup) { 80 | var subscribedAt time.Time 81 | var version SubscriberSetVersion 82 | collection := gcm.createOrModifyCollection(gcURL, func(collection *globCollection[T]) { 83 | subscribedAt, version = collection.subscribers.Subscribe(handler) 84 | }) 85 | 86 | collection.lock.RLock() 87 | defer collection.lock.RUnlock() 88 | 89 | if len(collection.nonNilValueNames) == 0 { 90 | handler.Notify(collection.url, nil, ads.SubscriptionMetadata{ 91 | SubscribedAt: subscribedAt, 92 | ModifiedAt: time.Time{}, 93 | CachedAt: time.Time{}, 94 | GlobCollectionURL: collection.url, 95 | }) 96 | } else { 97 | for v := range collection.values { 98 | wg := v.NotifyHandlerAfterSubscription(handler, GlobSubscription, subscribedAt, version) 99 | if wg != nil { 100 | wgs = append(wgs, wg) 101 | } 102 | } 103 | } 104 | 105 | return wgs 106 | } 107 | 108 | // Unsubscribe invokes globCollection.unsubscribe on the collection for the given URL, if it exists. 
109 | // If, as a result, the collection becomes empty, it invokes deleteCollectionIfEmpty. 110 | func (gcm *GlobCollectionsMap[T]) Unsubscribe(gcURL ads.GlobCollectionURL, handler ads.SubscriptionHandler[T]) { 111 | gcm.deleteCollectionIfEmpty(gcURL, func(collection *globCollection[T]) { 112 | collection.subscribers.Unsubscribe(handler) 113 | }) 114 | } 115 | 116 | // deleteCollectionIfEmpty attempts to completely remove the collection from the map, if and only if 117 | // there are no more subscribers and the collection is empty. 118 | func (gcm *GlobCollectionsMap[T]) deleteCollectionIfEmpty(gcURL ads.GlobCollectionURL, op func(collection *globCollection[T])) { 119 | gcm.collections.ComputeDeletion(gcURL, func(collection *globCollection[T]) bool { 120 | op(collection) 121 | 122 | empty := collection.hasNoValuesOrSubscribers() 123 | if empty { 124 | slog.Debug("Deleting collection", "url", gcURL) 125 | } 126 | 127 | return empty 128 | }) 129 | } 130 | 131 | // IsSubscribed checks if the given handler is subscribed to the collection. 132 | func (gcm *GlobCollectionsMap[T]) IsSubscribed(gcURL ads.GlobCollectionURL, handler ads.SubscriptionHandler[T]) (subscribed bool) { 133 | gcm.collections.ComputeIfPresent(gcURL, func(collection *globCollection[T]) { 134 | // Locking is not required here, as SubscriberSet is safe for concurrent access. 135 | subscribed = collection.subscribers.IsSubscribed(handler) 136 | }) 137 | return subscribed 138 | } 139 | 140 | // Size returns the size of the glob collection for the given URL, or 0 if no such collection exists. 
141 | func (gcm *GlobCollectionsMap[T]) Size(gcURL ads.GlobCollectionURL) (size int) { 142 | gcm.collections.ComputeIfPresent(gcURL, func(collection *globCollection[T]) { 143 | collection.lock.RLock() 144 | defer collection.lock.RUnlock() 145 | size = len(collection.nonNilValueNames) 146 | }) 147 | return size 148 | } 149 | -------------------------------------------------------------------------------- /internal/cache/resource_map.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "iter" 5 | 6 | "github.com/puzpuzpuz/xsync/v4" 7 | ) 8 | 9 | // ResourceMap is a concurrency-safe map. It deliberately does not expose bare Get or Put methods as 10 | // its concurrency model is based on the assumption that access to the backing values must be 11 | // strictly synchronized. Instead, all operations should be executed through the various Compute 12 | // methods. 13 | type ResourceMap[K comparable, V any] xsync.Map[K, V] 14 | 15 | func NewResourceMap[K comparable, V any]() *ResourceMap[K, V] { 16 | return (*ResourceMap[K, V])(xsync.NewMap[K, V]()) 17 | } 18 | 19 | // Compute first creates the value for the given key using the given function if no corresponding 20 | // entry exists, then it executes the given compute function. It returns the value itself, and a 21 | // boolean indicating whether the value was created. 22 | func (m *ResourceMap[K, V]) Compute( 23 | key K, 24 | newValue func(key K) V, 25 | compute func(value V), 26 | ) (v V, created bool) { 27 | v, _ = (*xsync.Map[K, V])(m).Compute(key, func(v V, loaded bool) (_ V, op xsync.ComputeOp) { 28 | if !loaded { 29 | v = newValue(key) 30 | op = xsync.UpdateOp 31 | created = true 32 | } 33 | compute(v) 34 | return v, op 35 | }) 36 | return v, created 37 | } 38 | 39 | // ComputeIfPresent invokes the given function only if a corresponding entry exists in the map for 40 | // the given key. 
41 | func (m *ResourceMap[K, V]) ComputeIfPresent(key K, f func(value V)) (wasPresent bool) { 42 | (*xsync.Map[K, V])(m).Compute(key, func(oldValue V, loaded bool) (_ V, op xsync.ComputeOp) { 43 | if !loaded { 44 | return oldValue, op 45 | } 46 | 47 | f(oldValue) 48 | wasPresent = true 49 | return oldValue, op 50 | }) 51 | return wasPresent 52 | } 53 | 54 | // ComputeDeletion loads the entry from the map if it still exists, then executes the given condition 55 | // function with the value. If the condition returns true, the entry is deleted from the map, 56 | // otherwise nothing happens. As a "compute" function, the condition is executed synchronously, in 57 | // other words, it is guaranteed that no other "compute" functions are executing on that entry. 58 | func (m *ResourceMap[K, V]) ComputeDeletion(key K, condition func(value V) (deleteEntry bool)) (deleted bool) { 59 | (*xsync.Map[K, V])(m).Compute(key, func(oldValue V, loaded bool) (_ V, op xsync.ComputeOp) { 60 | if !loaded { 61 | return oldValue, op 62 | } 63 | 64 | if condition(oldValue) { 65 | op = xsync.DeleteOp 66 | deleted = true 67 | } 68 | return oldValue, op 69 | }) 70 | 71 | return deleted 72 | } 73 | 74 | // Range returns an [iter.Seq2] that will iterate over all entries in this map. 75 | func (m *ResourceMap[K, V]) Range() iter.Seq2[K, V] { 76 | return (*xsync.Map[K, V])(m).Range 77 | } 78 | 79 | // Size returns the current number of entries in the map. 
80 | func (m *ResourceMap[K, V]) Size() int { 81 | return (*xsync.Map[K, V])(m).Size() 82 | } 83 | -------------------------------------------------------------------------------- /internal/cache/subscriber_set.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "iter" 5 | "sync" 6 | "sync/atomic" 7 | "time" 8 | 9 | "github.com/linkedin/diderot/ads" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | // SubscriberSetVersion is a monotonically increasing counter that tracks how many times subscribers 14 | // have been added to a given SubscriberSet. This means a subscriber can check whether they are in a 15 | // SubscriberSet by storing the version returned by SubscriberSet.Subscribe and comparing it against 16 | // the version returned by SubscriberSet.Iterator. 17 | type SubscriberSetVersion uint64 18 | 19 | // SubscriberSet is a concurrency-safe data structure that stores a set of unique subscribers. It is 20 | // specifically designed to support wildcard and glob subscriptions such that they can be shared by 21 | // multiple watchableValues instead of requiring each WatchableValue to store each subscriber. After 22 | // subscribing to a given value, the SubscriptionHandler is supposed to be notified of the current 23 | // value immediately, which usually simply means reading WatchableValue.currentValue and notifying 24 | // the handler. However, it is possible that the notification loop for the WatchableValue is already 25 | // running, and it could result in a double notification. To avoid this, this data structure 26 | // introduces a notion of versioning. This way, the notification loop can record which version it is 27 | // about to iterate over (in WatchableValue.lastSeenSubscriberSetVersions) such that subscribers can 28 | // determine whether the loop will notify them and avoid the double notification. 
This is done by 29 | // recording the version returned by SubscriberSet.Subscribe and checking whether it's equal to or 30 | // smaller than the version in WatchableValue.lastSeenSubscriberSetVersions. 31 | // 32 | // The implementation uses a sync.Map to store and iterate over the subscribers. In this case it's 33 | // impossible to use a normal map since the subscriber set will be iterated over frequently. However, 34 | // sync.Map provides no guarantees about what happens if the map is modified while another goroutine 35 | // is iterating over the entries. Specifically, if an entry is added during the iteration, the 36 | // iterator may or may not actually yield the new entry, which means the iterator may yield an entry 37 | // that was added _after_ Iterator was invoked, violating the Iterator contract that it will only 38 | // yield entries that were added before. To get around this, the returned iterator simply records the 39 | // version at which it was initially created, and drops entries that have a greater version, making 40 | // it always consistent. 41 | type SubscriberSet[T proto.Message] struct { 42 | // Protects entry creation in the set. 43 | lock sync.Mutex 44 | // Maps SubscriptionHandler instances to the subscriber instance containing the metadata. 45 | subscribers sync.Map // Real type: map[SubscriptionHandler[T]]*subscriber 46 | // The current subscriber set version. 47 | version SubscriberSetVersion 48 | // Stores the current number of subscribers. 49 | size atomic.Int64 50 | } 51 | 52 | type subscriber struct { 53 | subscribedAt time.Time 54 | version SubscriberSetVersion 55 | } 56 | 57 | // IsSubscribed checks whether the given handler is subscribed to this set. 
58 | func (m *SubscriberSet[T]) IsSubscribed(handler ads.SubscriptionHandler[T]) bool { 59 | if m == nil { 60 | return false 61 | } 62 | 63 | _, ok := m.subscribers.Load(handler) 64 | return ok 65 | } 66 | 67 | // Subscribe registers the given SubscriptionHandler as a subscriber and returns the time and version 68 | // at which the subscription was processed. The returned version can be compared against the version 69 | // returned by Iterator to check whether the given handler is present in the iterator. 70 | func (m *SubscriberSet[T]) Subscribe(handler ads.SubscriptionHandler[T]) (time.Time, SubscriberSetVersion) { 71 | m.lock.Lock() 72 | defer m.lock.Unlock() 73 | 74 | m.version++ 75 | s := &subscriber{ 76 | subscribedAt: timeProvider(), 77 | version: m.version, 78 | } 79 | _, loaded := m.subscribers.Swap(handler, s) 80 | if !loaded { 81 | m.size.Add(1) 82 | } 83 | 84 | return s.subscribedAt, s.version 85 | } 86 | 87 | // Unsubscribe removes the given handler from the set, and returns whether the set is now empty as a 88 | // result of this unsubscription. 89 | func (m *SubscriberSet[T]) Unsubscribe(handler ads.SubscriptionHandler[T]) (empty bool) { 90 | _, loaded := m.subscribers.LoadAndDelete(handler) 91 | if !loaded { 92 | return m.size.Load() == 0 93 | } 94 | 95 | return m.size.Add(-1) == 0 96 | } 97 | 98 | // Size returns the number of subscribers in the set. For convenience, returns 0 if the receiver is 99 | // nil. 100 | func (m *SubscriberSet[T]) Size() int { 101 | if m == nil { 102 | return 0 103 | } 104 | return int(m.size.Load()) 105 | } 106 | 107 | // IsEmpty is a convenience function that checks whether the set is empty˜. 108 | func (m *SubscriberSet[T]) IsEmpty() bool { 109 | return m.Size() == 0 110 | } 111 | 112 | // Version returns the current version of this set. Invoking [SubscriberSet.SnapshotIterator] with 113 | // the returned version will only yield subscribers added to this set at or before that version. 
114 | func (m *SubscriberSet[T]) Version() SubscriberSetVersion { 115 | if m == nil { 116 | return 0 117 | } 118 | 119 | m.lock.Lock() 120 | defer m.lock.Unlock() 121 | return m.version 122 | } 123 | 124 | type SubscriberSetIterator[T proto.Message] iter.Seq2[ads.SubscriptionHandler[T], time.Time] 125 | 126 | // Iterator returns a [SubscriberSetIterator] that will iterate over all the subscribers currently in 127 | // the set. 128 | func (m *SubscriberSet[T]) Iterator() SubscriberSetIterator[T] { 129 | return func(yield func(ads.SubscriptionHandler[T], time.Time) bool) { 130 | if m == nil { 131 | return 132 | } 133 | 134 | for key, value := range m.subscribers.Range { 135 | if !yield(key.(ads.SubscriptionHandler[T]), value.(*subscriber).subscribedAt) { 136 | break 137 | } 138 | } 139 | } 140 | } 141 | 142 | // SnapshotIterator returns a [SubscriberSetIterator] that will only iterate over the subscribers 143 | // that were added before or at the given version. 144 | func (m *SubscriberSet[T]) SnapshotIterator(v SubscriberSetVersion) SubscriberSetIterator[T] { 145 | return func(yield func(ads.SubscriptionHandler[T], time.Time) bool) { 146 | if m == nil { 147 | return 148 | } 149 | 150 | for key, value := range m.subscribers.Range { 151 | s := value.(*subscriber) 152 | if s.version > v { 153 | continue 154 | } 155 | 156 | if !yield(key.(ads.SubscriptionHandler[T]), s.subscribedAt) { 157 | break 158 | } 159 | } 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /internal/cache/subscriber_set_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/linkedin/diderot/ads" 8 | "github.com/stretchr/testify/require" 9 | . 
"google.golang.org/protobuf/types/known/timestamppb" 10 | ) 11 | 12 | type noopHandler byte 13 | 14 | func (*noopHandler) Notify(string, *ads.Resource[*Timestamp], ads.SubscriptionMetadata) {} 15 | 16 | type iterateArgs struct { 17 | handler ads.SubscriptionHandler[*Timestamp] 18 | subscribedAt time.Time 19 | } 20 | 21 | func checkIterate(t *testing.T, m *SubscriberSet[*Timestamp], expectedV SubscriberSetVersion, expectedArgs ...iterateArgs) { 22 | require.Equal(t, m.Size(), len(expectedArgs)) 23 | v := m.Version() 24 | require.Equal(t, expectedV, v) 25 | 26 | var actualArgs []iterateArgs 27 | 28 | for handler, subscribedAt := range m.SnapshotIterator(v) { 29 | actualArgs = append(actualArgs, iterateArgs{ 30 | handler: handler, 31 | subscribedAt: subscribedAt, 32 | }) 33 | } 34 | require.ElementsMatch(t, expectedArgs, actualArgs) 35 | } 36 | 37 | func TestSubscriberMap(t *testing.T) { 38 | s := new(SubscriberSet[*Timestamp]) 39 | checkIterate(t, s, 0) 40 | 41 | h1 := new(noopHandler) 42 | sAt1, v := s.Subscribe(h1) 43 | require.Equal(t, SubscriberSetVersion(1), v) 44 | require.True(t, s.IsSubscribed(h1)) 45 | 46 | checkIterate(t, s, 1, 47 | iterateArgs{ 48 | handler: h1, 49 | subscribedAt: sAt1, 50 | }, 51 | ) 52 | 53 | h2 := new(noopHandler) 54 | sAt2, v := s.Subscribe(h2) 55 | require.NotEqual(t, sAt1, sAt2) 56 | require.Equal(t, SubscriberSetVersion(2), v) 57 | require.True(t, s.IsSubscribed(h2)) 58 | 59 | checkIterate(t, s, 2, 60 | iterateArgs{ 61 | handler: h1, 62 | subscribedAt: sAt1, 63 | }, 64 | iterateArgs{ 65 | handler: h2, 66 | subscribedAt: sAt2, 67 | }, 68 | ) 69 | 70 | sAt3, v := s.Subscribe(h1) 71 | require.NotEqual(t, sAt1, sAt3) 72 | require.Equal(t, SubscriberSetVersion(3), v) 73 | require.True(t, s.IsSubscribed(h1)) 74 | 75 | checkIterate(t, s, 3, 76 | iterateArgs{ 77 | handler: h2, 78 | subscribedAt: sAt2, 79 | }, 80 | iterateArgs{ 81 | handler: h1, 82 | subscribedAt: sAt3, 83 | }, 84 | ) 85 | 86 | s.Unsubscribe(h1) 87 | require.False(t, 
s.IsSubscribed(h1)) 88 | checkIterate(t, s, 3, 89 | iterateArgs{ 90 | handler: h2, 91 | subscribedAt: sAt2, 92 | }) 93 | 94 | s.Unsubscribe(h2) 95 | require.False(t, s.IsSubscribed(h2)) 96 | checkIterate(t, s, 3) 97 | } 98 | -------------------------------------------------------------------------------- /internal/cache/subscription_type.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | // subscriptionType describes the ways a client can subscribe to a resource. 4 | type subscriptionType byte 5 | 6 | // The following subscriptionType constants define the ways a client can subscribe to a resource. See 7 | // RawCache.Subscribe for additional details. 8 | const ( 9 | // An ExplicitSubscription means the client subscribed to a resource by explicit providing its name. 10 | ExplicitSubscription = subscriptionType(iota) 11 | // A GlobSubscription means the client subscribed to a resource by specifying its parent glob 12 | // collection URL, implicitly subscribing it to all the resources that are part of the collection. 13 | GlobSubscription 14 | // A WildcardSubscription means the client subscribed to a resource by specifying the wildcard 15 | // (ads.WildcardSubscription), implicitly subscribing it to all resources in the cache. 
16 | WildcardSubscription 17 | 18 | subscriptionTypes = iota 19 | ) 20 | 21 | func (t subscriptionType) isImplicit() bool { 22 | return t != ExplicitSubscription 23 | } 24 | -------------------------------------------------------------------------------- /internal/cache/subscription_type_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestSubscriptionType(t *testing.T) { 10 | require.False(t, ExplicitSubscription.isImplicit()) 11 | require.True(t, GlobSubscription.isImplicit()) 12 | require.True(t, WildcardSubscription.isImplicit()) 13 | } 14 | -------------------------------------------------------------------------------- /internal/cache/watchable_value_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | "github.com/linkedin/diderot/ads" 9 | "github.com/linkedin/diderot/testutils" 10 | "google.golang.org/protobuf/types/known/timestamppb" 11 | ) 12 | 13 | // Benchmarks the actual notification loop. There was a problem where the loop was leaking something 14 | // to heap, causing it to add GC pressure. 15 | func BenchmarkNotificationLoop(b *testing.B) { 16 | v := NewValue[*timestamppb.Timestamp]("foo", 1) 17 | for b.Loop() { 18 | v.notificationLoop() 19 | } 20 | } 21 | 22 | // This benchmarks the worst case scenario, where a goroutine is re-created every time to deliver the 23 | // notification, instead of being reused because the resource update caused the loop to start over. 
24 | func BenchmarkValueSetClear(b *testing.B) { 25 | SetTimeProvider(func() (t time.Time) { return t }) 26 | b.Cleanup(func() { 27 | SetTimeProvider(time.Now) 28 | }) 29 | 30 | var done sync.WaitGroup 31 | 32 | done.Add(1) 33 | simpleHandler := testutils.NewSubscriptionHandler( 34 | func(name string, r *ads.Resource[*timestamppb.Timestamp], _ ads.SubscriptionMetadata) { 35 | done.Done() 36 | }, 37 | ) 38 | v := NewValue[*timestamppb.Timestamp]("foo", 1) 39 | v.Subscribe(simpleHandler) 40 | done.Wait() 41 | 42 | r := ads.NewResource("foo", "0", timestamppb.Now()) 43 | for b.Loop() { 44 | done.Add(1) 45 | v.Set(0, r, time.Time{}) 46 | done.Wait() 47 | 48 | done.Add(1) 49 | v.Clear(0, time.Time{}) 50 | done.Wait() 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /internal/client/watchers.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "iter" 7 | "maps" 8 | "sync" 9 | 10 | "github.com/linkedin/diderot/ads" 11 | "github.com/linkedin/diderot/internal/utils" 12 | "google.golang.org/protobuf/proto" 13 | ) 14 | 15 | // Watcher is a copy of the interface of the same name in the root package, to avoid import cycles. 16 | type Watcher[T proto.Message] interface { 17 | Notify(resources iter.Seq2[string, *ads.Resource[T]]) error 18 | } 19 | 20 | // RawResourceHandler is a non-generic interface implemented by [ResourceHandler]. Used by the 21 | // non-generic [github.com/linkedin/diderot.ADSClient]. 22 | type RawResourceHandler interface { 23 | // AllSubscriptions returns a sequence of all the subscriptions in this client. 24 | AllSubscriptions() iter.Seq[string] 25 | // HandleResponses should be called whenever responses are received. Accepts a slice of responses, 26 | // since the client may support chunking. The given boolean parameter indicates whether this is the 27 | // set of responses received for a stream. 
This is used to determine whether any resources were 28 | // deleted while the client was disconnected. 29 | HandleResponses(isFirst bool, responses []*ads.DeltaDiscoveryResponse) error 30 | } 31 | 32 | // Ensure that [ResourceHandler] implements the [RawResourceHandler] interface. 33 | var _ RawResourceHandler = (*ResourceHandler[proto.Message])(nil) 34 | 35 | // ResourceHandler implements the core logic of managing notifications for watchers. 36 | type ResourceHandler[T proto.Message] struct { 37 | lock sync.Mutex 38 | 39 | // All the resources currently known by this client. 40 | resources map[string]*ads.Resource[T] 41 | 42 | // Maps resource name to watchers of the resource. 43 | subscriptions map[string]*subscription[T] 44 | // Maps glob collection URL to watchers of the collection. 45 | globSubscriptions map[ads.GlobCollectionURL]*globSubscription[T] 46 | // Contains the set of wildcard watchers, if present. 47 | wildcardSubscription *subscription[T] 48 | } 49 | 50 | // resolve returns an [iter.Seq2] that resolves the resources for the given sequence of resource 51 | // names. The [ads.Resource] will be nil if no such resource is known. 52 | func (h *ResourceHandler[T]) resolve(in iter.Seq[string]) iter.Seq2[string, *ads.Resource[T]] { 53 | return func(yield func(string, *ads.Resource[T]) bool) { 54 | for name := range in { 55 | if !yield(name, h.resources[name]) { 56 | return 57 | } 58 | } 59 | } 60 | } 61 | 62 | func (h *ResourceHandler[T]) resolveSingle(name string) iter.Seq2[string, *ads.Resource[T]] { 63 | return func(yield func(string, *ads.Resource[T]) bool) { 64 | yield(name, h.resources[name]) 65 | } 66 | } 67 | 68 | // setResource updates the map of known resources. Returns a boolean indicating whether the resource 69 | // was actually changed. 
70 | func (h *ResourceHandler[T]) setResource(name string, resource *ads.Resource[T]) (updated bool) { 71 | var previous *ads.Resource[T] 72 | previous, ok := h.resources[name] 73 | if !ok && resource == nil { 74 | // Ignore deletions for unknown resources 75 | return false 76 | } 77 | if ok && resource.Equals(previous) { 78 | // Ignore updates identical to the most recently seen resource. 79 | return false 80 | } 81 | 82 | if resource == nil { 83 | delete(h.resources, name) 84 | } else { 85 | h.resources[name] = resource 86 | } 87 | return true 88 | } 89 | 90 | func NewResourceHandler[T proto.Message]() *ResourceHandler[T] { 91 | return &ResourceHandler[T]{ 92 | resources: make(map[string]*ads.Resource[T]), 93 | subscriptions: make(map[string]*subscription[T]), 94 | globSubscriptions: make(map[ads.GlobCollectionURL]*globSubscription[T]), 95 | } 96 | } 97 | 98 | // AddWatcher registers the given [Watcher] against the given resource name. The watcher will be 99 | // notified whenever the resource is created, updated or deleted. The returned boolean indicates 100 | // whether the watcher was a new registration, or was already previously registered. If a value for 101 | // the given resource is already known, the watcher is immediately notified. 102 | func (h *ResourceHandler[T]) AddWatcher(name string, w Watcher[T]) bool { 103 | h.lock.Lock() 104 | defer h.lock.Unlock() 105 | 106 | // contains the set of watchers to update 107 | var watchers utils.Set[Watcher[T]] 108 | // set if a value for the resource is already known. 
109 | var resources iter.Seq2[string, *ads.Resource[T]] 110 | 111 | if name == ads.WildcardSubscription { 112 | if h.wildcardSubscription == nil { 113 | h.wildcardSubscription = newSubscription[T]() 114 | } 115 | watchers = h.wildcardSubscription.watchers 116 | if h.wildcardSubscription.initialized { 117 | resources = maps.All(h.resources) 118 | } 119 | } else if gcURL, err := ads.ParseGlobCollectionURL[T](name); err == nil { 120 | globSub, ok := h.globSubscriptions[gcURL] 121 | if !ok { 122 | globSub = &globSubscription[T]{ 123 | subscription: *newSubscription[T](), 124 | entries: make(utils.Set[string]), 125 | } 126 | h.globSubscriptions[gcURL] = globSub 127 | } 128 | watchers = globSub.watchers 129 | if globSub.initialized { 130 | resources = h.resolve(globSub.entries.Values()) 131 | } 132 | } else { 133 | sub, ok := h.subscriptions[name] 134 | if !ok { 135 | sub = newSubscription[T]() 136 | h.subscriptions[name] = sub 137 | } 138 | // In the event that there is already data from another subscription for this specific resource, 139 | // immediately satisfy the watcher. 
		_, sub.initialized = h.resources[name]
		watchers = sub.watchers
		if sub.initialized {
			resources = h.resolveSingle(name)
		}
	}

	if resources != nil {
		// Best-effort initial notification: the error is deliberately discarded because the
		// watcher is registered below regardless and will receive subsequent updates.
		_ = w.Notify(resources)
	}

	return watchers.Add(w)
}

// AllSubscriptions returns an iterator over the names of every active subscription: explicitly
// named resources, glob collection URLs (in string form) and, if present, the wildcard
// subscription. The handler's lock is held for the duration of the iteration.
func (h *ResourceHandler[T]) AllSubscriptions() iter.Seq[string] {
	return func(yield func(string) bool) {
		h.lock.Lock()
		defer h.lock.Unlock()

		for k := range h.subscriptions {
			if !yield(k) {
				return
			}
		}
		for k := range h.globSubscriptions {
			if !yield(k.String()) {
				return
			}
		}
		if h.wildcardSubscription != nil {
			yield(ads.WildcardSubscription)
		}
	}
}

// HandleResponses ingests a batch of delta responses, updates the stored resources and notifies
// all impacted watchers (explicit, glob collection and wildcard). isFirst must be true for the
// first batch of a stream so deletions that happened while the client was disconnected can be
// detected (see receivedResources below). Per-resource errors are accumulated and returned
// joined via errors.Join rather than aborting the batch.
func (h *ResourceHandler[T]) HandleResponses(isFirst bool, responses []*ads.DeltaDiscoveryResponse) error {
	h.lock.Lock()
	defer h.lock.Unlock()

	var errs []error
	addErr := func(err error) {
		errs = append(errs, err)
	}
	// notifyWatchers fans the given resource sequence out to every watcher of sub, collecting
	// any notification errors.
	notifyWatchers := func(sub *subscription[T], seq iter.Seq2[string, *ads.Resource[T]]) {
		for w := range sub.watchers {
			err := w.Notify(seq)
			if err != nil {
				addErr(err)
			}
		}
	}

	totalAddedResources := 0
	totalDeletedResources := 0
	for _, response := range responses {
		totalAddedResources += len(response.Resources)
		totalDeletedResources += len(response.RemovedResources)
	}

	if totalAddedResources+totalDeletedResources == 0 {
		return fmt.Errorf("empty response")
	}

	// Contains the set of resource names that wildcard watchers should be notified of. Only set if any
	// wildcard watchers are registered.
	var wildcardUpdates utils.Set[string]
	if h.wildcardSubscription != nil {
		wildcardUpdates = make(utils.Set[string], totalAddedResources+totalDeletedResources)
	}
	// Contains the set of resource names received. Only set if this is the first set of responses for
	// the stream, as it is used to determine whether any resources were deleted while the client was
	// disconnected. For example, suppose resources foo and bar are present on the ADS server. If a
	// wildcard watcher is registered, it will initially receive updates for those two resources. Then
	// the client disconnects, reconnects and resubmits its wildcard subscription. If bar was deleted
	// during the disconnect, the server will only send back an update for foo, but never an explicit
	// deletion for bar. This set is therefore used to compare against h.resources, i.e. the set
	// known/previously received resources to see if wildcard and glob collection watchers need to be
	// notified of any deletions.
	var receivedResources utils.Set[string]
	if isFirst {
		receivedResources = make(utils.Set[string], totalAddedResources)
	}

	// Per-glob-subscription sets of names to notify once the loop completes.
	globUpdates := make(map[*globSubscription[T]]utils.Set[string])

	for name, r := range iterateResources(responses) {
		sub := h.subscriptions[name]

		var globSub *globSubscription[T]
		gcURL, gcResourceName, err := ads.ParseGlobCollectionURN[T](name)
		if err == nil {
			globSub = h.globSubscriptions[gcURL]
		}

		if sub == nil && globSub == nil && h.wildcardSubscription == nil {
			addErr(fmt.Errorf("not subscribed to resource %q", name))
			continue
		}

		// r == nil signals a deletion (see iterateResources); otherwise unmarshal the update.
		var resource *ads.Resource[T]
		if r != nil {
			resource, err = ads.UnmarshalRawResource[T](r)
			if err != nil {
				addErr(err)
				continue
			}
			if isFirst {
				receivedResources.Add(name)
			}
		}

		updated := h.setResource(name, resource)

		if sub != nil && (!sub.initialized || updated) {
			sub.initialized = true
			notifyWatchers(sub, h.resolve(func(yield func(string) bool) { yield(name) }))
		}

		if globSub != nil {
			updates := GetNestedMap(globUpdates, globSub)
			// A deletion whose member name is the wildcard empties the entire collection: every
			// previously known entry becomes a pending update for the collection's watchers.
			if resource == nil && gcResourceName == ads.WildcardSubscription {
				maps.Copy(updates, globSub.entries)
				clear(globSub.entries)
				// NOTE(review): this continue also skips the wildcardUpdates bookkeeping below for
				// this name — confirm wildcard watchers are not meant to observe collection-wide
				// deletion markers.
				continue
			} else if !globSub.initialized || updated {
				updates.Add(name)
				if resource != nil {
					globSub.entries.Add(name)
				} else {
					globSub.entries.Remove(name)
				}
			}
		}

		if h.wildcardSubscription != nil && (!h.wildcardSubscription.initialized || updated) {
			wildcardUpdates.Add(name)
		}
	}

	if isFirst {
		// Anything previously known but absent from this first batch was deleted while the client
		// was disconnected (see the receivedResources comment above).
		for name := range h.resources {
			if _, ok := receivedResources[name]; !ok {
				delete(h.resources, name)
				if h.wildcardSubscription != nil {
					wildcardUpdates.Add(name)
				}
			}
		}
	}

	if h.wildcardSubscription != nil {
		h.wildcardSubscription.initialized = true

		if len(wildcardUpdates) > 0 {
			notifyWatchers(h.wildcardSubscription, h.resolve(wildcardUpdates.Values()))
		}
	}

	for globSub, updates := range globUpdates {
		globSub.initialized = true
		if len(updates) > 0 {
			notifyWatchers(&globSub.subscription, h.resolve(updates.Values()))
		}
	}

	return errors.Join(errs...)
}

// iterateResources returns an [iter.Seq2] that iterates over all the resources in the given
// response. If the [ads.RawResource] is nil, the resource is being deleted.
func iterateResources(responses []*ads.DeltaDiscoveryResponse) iter.Seq2[string, *ads.RawResource] {
	return func(yield func(string, *ads.RawResource) bool) {
		for _, res := range responses {
			for _, r := range res.Resources {
				if !yield(r.Name, r) {
					return
				}
			}
			for _, name := range res.RemovedResources {
				// A nil RawResource signals deletion to the consumer.
				if !yield(name, nil) {
					return
				}
			}
		}
	}
}

// newSubscription returns a subscription with an empty, non-nil watcher set.
func newSubscription[T proto.Message]() *subscription[T] {
	return &subscription[T]{
		watchers: make(utils.Set[Watcher[T]]),
	}
}

// subscription tracks the watchers registered against a single subscription key, along with
// whether the subscription has received its initial data.
type subscription[T proto.Message] struct {
	// initialized is true once the first response for this subscription has been processed.
	initialized bool
	// watchers is the set of watchers to notify when the subscribed resource(s) change.
	watchers utils.Set[Watcher[T]]
}

// globSubscription is a subscription to a glob collection, which additionally tracks the names
// of the resources currently present in the collection.
type globSubscription[T proto.Message] struct {
	subscription[T]
	// entries is the set of resource names currently in the collection.
	entries utils.Set[string]
}

// GetNestedMap is a utility function for nested maps. It will create the map at the given key if it
// does not already exist, then returns the corresponding map.
func GetNestedMap[K1, K2 comparable, V any, M ~map[K2]V](m map[K1]M, k K1) M {
	v, ok := m[k]
	if !ok {
		v = make(M)
		m[k] = v
	}
	return v
}
-------------------------------------------------------------------------------- /internal/server/handlers.go: --------------------------------------------------------------------------------
package internal

import (
	"context"
	"log/slog"
	"maps"
	"sync"
	"time"

	"github.com/linkedin/diderot/ads"
	"github.com/linkedin/diderot/internal/utils"
	serverstats "github.com/linkedin/diderot/stats/server"
	"golang.org/x/time/rate"
	"google.golang.org/protobuf/proto"
)

// SendBufferSizeEstimator is a copy of the interface in the root package, to avoid import cycles.
type SendBufferSizeEstimator interface {
	// EstimateSubscriptionSize returns a size hint for the send buffer backing the given
	// subscription request.
	EstimateSubscriptionSize(streamCtx context.Context, typeURL string, resourceNamesSubscribe []string) int
}

// BatchSubscriptionHandler is an extension of the SubscriptionHandler interface in the root package
// which allows a handler to be notified that a batch of calls to Notify is about to be received
// (StartNotificationBatch). The batch of notifications should not be sent to the client until all
// notifications for that batch have been received (EndNotificationBatch). Start and End will never
// be invoked out of order, i.e. there will never be a call to EndNotificationBatch without a call to
// StartNotificationBatch immediately preceding it. However, SubscriptionHandler.Notify can be
// invoked at any point.
type BatchSubscriptionHandler interface {
	StartNotificationBatch(map[string]string, int)
	ads.RawSubscriptionHandler
	EndNotificationBatch()
}

// sendBuffer is an alias for the map type used by the handler to accumulate pending resource updates
// before sending them to the client.
type sendBuffer map[string]serverstats.SentResource

// newHandler creates a handler for the given type URL and starts its background flush loop
// (handler.loop) in a new goroutine. The loop runs until ctx is cancelled or send returns an
// error. If ignoreDeletes is true, Notify drops deletion notifications (nil resources).
func newHandler(
	ctx context.Context,
	typeURL string,
	granularLimiter handlerLimiter,
	globalLimiter handlerLimiter,
	statsHandler serverstats.Handler,
	ignoreDeletes bool,
	send func(entries sendBuffer) error,
) *handler {
	h := &handler{
		typeURL:                       typeURL,
		granularLimiter:               granularLimiter,
		globalLimiter:                 globalLimiter,
		statsHandler:                  statsHandler,
		ctx:                           ctx,
		ignoreDeletes:                 ignoreDeletes,
		send:                          send,
		immediateNotificationReceived: newNotifyOnceChan(),
		notificationReceived:          newNotifyOnceChan(),
	}
	go h.loop()
	return h
}

// newNotifyOnceChan returns a notifyOnceChan with a buffer of 1 so notify never blocks.
func newNotifyOnceChan() notifyOnceChan {
	return make(chan struct{}, 1)
}

// notifyOnceChan is a resettable chan that only receives a notification once. It is exclusively
// meant to be used by handler. All methods should be invoked while holding the corresponding
// handler.lock.
type notifyOnceChan chan struct{}

// notify notifies the channel using a non-blocking send
func (ch notifyOnceChan) notify() {
	select {
	case ch <- struct{}{}:
	default:
	}
}

// reset ensures the channel has no pending notifications in case they were never read (this can
// happen if a notification comes in after the granular rate limit clears but before the
// corresponding handler.lock is acquired).
func (ch notifyOnceChan) reset() {
	select {
	// clear the channel if it has a pending notification. This is required since
	// immediateNotificationReceived can be notified _after_ the granular limit clears. If it isn't
	// cleared during the reset, the loop will read from it and incorrectly detect an immediate
	// notification.
	case <-ch:
	// otherwise return immediately if the channel is empty
	default:
	}
}

// sendBufferPool recycles sendBuffer maps across responses to reduce allocations.
var sendBufferPool = sync.Pool{New: func() any { return make(sendBuffer) }}

// handler implements the BatchSubscriptionHandler interface using a backing map to aggregate updates
// as they come in, and flushing them out, according to when the limiter permits it.
type handler struct {
	typeURL         string
	granularLimiter handlerLimiter
	globalLimiter   handlerLimiter
	statsHandler    serverstats.Handler
	// lock guards entries, the notifyOnceChans, batchStarted and initialResourceVersions.
	lock          sync.Mutex
	ctx           context.Context
	ignoreDeletes bool
	send          func(entries sendBuffer) error

	// entries accumulates pending updates until the loop flushes them via send. Lazily
	// (re)acquired from sendBufferPool by Notify; nil between flushes.
	entries sendBuffer

	// The following notifyOnceChan instances are the signaling mechanism between loop and Notify. Calls
	// to Notify will first invoke notifyOnceChan.notify on immediateNotificationReceived based on the
	// contents of the subscription metadata, then call notify on notificationReceived. loop waits on the
	// channel that backs notificationReceived to be signaled and once the first notification is
	// received, waits for the global rate limit to clear. This allows updates to keep accumulating. It
	// then checks whether immediateNotificationReceived has been signaled, and if so skips the granular
	// rate limiter. Otherwise, it either waits for the granular rate limit to clear, or
	// immediateNotificationReceived to be signaled, whichever comes first. Only then does it invoke
	// swapEntries which resets notificationReceived, immediateNotificationReceived and entries to a
	// state where they can receive more notifications while, in the background, it invokes send with all
	// accumulated entries up to this point. Once send completes, it returns to waiting on
	// notificationReceived. All operations involving these channels will exit early if ctx is cancelled,
	// terminating the loop.
	immediateNotificationReceived notifyOnceChan
	notificationReceived          notifyOnceChan

	// If batchStarted is true, Notify will not notify notificationReceived. This allows the batch to
	// complete before the response is sent, minimizing the number of responses.
	batchStarted bool

	// initialResourceVersions is a map of resource names to their initial versions.
	// this informs the server of the versions of the resources the xDS client knows of.
	initialResourceVersions map[string]*initialResourceVersion
}

type initialResourceVersion struct {
	// initial version of the resource, which the xDS client has seen.
	version string
	// received flag indicates if the resource has been received from the server and skipped from the response
	// being sent. we are maintaining this flag to differentiate between the resource which is deleted on cache and
	// the resource which is not updated since client has last seen it.
	received bool
}

// swapEntries grabs the lock then swaps the entries map to a nil map. It resets notificationReceived
// and immediateNotificationReceived, and returns original entries map that was swapped.
func (h *handler) swapEntries() sendBuffer {
	h.lock.Lock()
	defer h.lock.Unlock()
	entries := h.entries
	h.entries = nil
	h.notificationReceived.reset()
	h.immediateNotificationReceived.reset()
	return entries
}

// loop is the handler's background goroutine: it waits for a notification, clears the global and
// (unless short-circuited) the granular rate limiter, then flushes the accumulated entries via
// send. It terminates when ctx is cancelled or send returns an error.
func (h *handler) loop() {
	for {
		select {
		case <-h.ctx.Done():
			return
		case <-h.notificationReceived:
			// Always wait for the global rate limiter to clear
			if waitForGlobalLimiter(h.ctx, h.globalLimiter, h.statsHandler) != nil {
				return
			}
			// Wait for the granular rate limiter
			if h.waitForGranularLimiterOrShortCircuit() != nil {
				return
			}
		}

		entries := h.swapEntries()

		var start time.Time
		if h.statsHandler != nil {
			start = time.Now()
		}

		err := h.send(entries)

		if h.statsHandler != nil {
			h.statsHandler.HandleServerEvent(h.ctx, &serverstats.ResponseSent{
				TypeURL:   h.typeURL,
				Resources: entries,
				Duration:  time.Since(start),
			})
		}

		// Return the used map to the pool after clearing it.
		clear(entries)
		sendBufferPool.Put(entries)

		if err != nil {
			return
		}
	}
}

// waitForGlobalLimiter blocks until the global limiter's reservation clears or ctx is cancelled,
// reporting the time spent to statsHandler if one is configured.
func waitForGlobalLimiter(
	ctx context.Context,
	globalLimiter handlerLimiter,
	statsHandler serverstats.Handler,
) error {
	if statsHandler != nil {
		start := time.Now()
		defer func() {
			statsHandler.HandleServerEvent(ctx, &serverstats.TimeInGlobalRateLimiter{Duration: time.Since(start)})
		}()
	}

	reservation, cancel := globalLimiter.reserve()
	defer cancel()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-reservation:
		return nil
	}
}

// waitForGranularLimiterOrShortCircuit will acquire a reservation from granularLimiter and wait on
// it, but will short circuit the reservation if an immediate notification is received (or if ctx is
// canceled).
func (h *handler) waitForGranularLimiterOrShortCircuit() error {
	reservation, cancel := h.granularLimiter.reserve()
	defer cancel()

	for {
		select {
		case <-h.ctx.Done():
			return h.ctx.Err()
		case <-h.immediateNotificationReceived:
			// If an immediate notification is received, immediately return instead of waiting for the granular
			// limit. Without this, a bootstrapping client may be forced to wait for the initial versions of the
			// resources it is interested in. The purpose of the rate limiter is to avoid overwhelming the
			// client, however if the client is healthy enough to request new resources then those resources
			// should be sent without delay. Do note, however, that the responses will still always be rate
			// limited by the global limiter.
			return nil
		case <-reservation:
			// Otherwise, wait for the granular rate limit to clear.
			return nil
		}
	}
}

// Notify records an update (or deletion, when r is nil) for the named resource in the pending
// send buffer and signals the flush loop, unless a batch is in progress (see
// StartNotificationBatch) in which case the signal is deferred to EndNotificationBatch.
func (h *handler) Notify(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata) {
	h.lock.Lock()
	defer h.lock.Unlock()

	// A zero CachedAt means the cache has never seen this resource, i.e. the client requested
	// something unknown; surface that to the stats handler.
	if metadata.CachedAt.IsZero() && h.statsHandler != nil {
		h.statsHandler.HandleServerEvent(h.ctx, &serverstats.UnknownResourceRequested{
			TypeURL:      h.typeURL,
			ResourceName: name,
		})
	}

	if r == nil && h.ignoreDeletes {
		return
	}

	if h.entries == nil {
		h.entries = sendBufferPool.Get().(sendBuffer)
	}

	// Skip resources whose version the client already knows (initial resource versions).
	if h.handleMatchFromIRV(name, r) {
		return
	}

	h.entries[name] = serverstats.SentResource{
		Resource: r,
		Metadata: metadata,
		QueuedAt: time.Now(),
	}

	if r != nil && metadata.GlobCollectionURL != "" {
		// When a glob collection is empty, it is signaled to the client with a corresponding deletion of
		// that collection's name. For example, if a collection Foo/* becomes empty (or the client subscribed
		// to a collection that does not exist), it will receive a deletion notification for Foo/*. There is
		// an edge case in the following scenario: suppose a collection currently has some resource Foo/A in
		// it. Upon subscribing, the handler will be notified that the resource exists. Foo/A is then
		// removed, so the handler receives a notification that Foo/A is removed, and because Foo/* is empty
		// it also receives a corresponding notification. But, soon after, resource Foo/B is created,
		// reviving Foo/* and the handler receives the corresponding notification for Foo/B. At this point,
		// if the response were to be sent as-is, it would contain both the creation of Foo/B and the
		// deletion of Foo/*. Depending on the order in which the client processes the response's contents,
		// it may ignore Foo/B altogether. To avoid this, always clear out the deletion of Foo/* when a
		// notification for the creation of an entry within Foo/* is received.
		delete(h.entries, metadata.GlobCollectionURL)
	}

	if !h.batchStarted {
		h.notificationReceived.notify()
	}
}

// ResourceMarshalError reports a resource that failed to marshal to the stats handler, if any.
func (h *handler) ResourceMarshalError(name string, resource proto.Message, err error) {
	if h.statsHandler != nil {
		h.statsHandler.HandleServerEvent(h.ctx, &serverstats.ResourceMarshalError{
			ResourceName: name,
			Resource:     resource,
			Err:          err,
		})
	}
}

// StartNotificationBatch begins a notification batch (see BatchSubscriptionHandler). It records
// the client's initial resource versions so unchanged resources can be skipped, or — when no IRV
// is provided — preallocates the send buffer using the given size estimate.
func (h *handler) StartNotificationBatch(initialResourceVersions map[string]string, estimatedSize int) {
	h.lock.Lock()
	defer h.lock.Unlock()

	if len(initialResourceVersions) > 0 {
		h.initialResourceVersions = make(map[string]*initialResourceVersion, len(initialResourceVersions))

		// setting the initial version of resources to filter out unchanged resources.
		for name, version := range initialResourceVersions {
			h.initialResourceVersions[name] = &initialResourceVersion{version: version}
		}
	} else if estimatedSize > 0 {
		// Only preallocate the send buffer if the initialResourceVersions is empty. Otherwise, it's very
		// likely that the buffer will be underutilized and waste resources.
		prevBuf := h.entries
		h.entries = make(sendBuffer, estimatedSize+len(prevBuf))

		if len(prevBuf) > 0 {
			// Is possible that some notifications were already pending, so ensure those are not lost before
			// returning the to the pool.
			maps.Copy(h.entries, prevBuf)
			clear(prevBuf)
			sendBufferPool.Put(prevBuf)
		}
	}
	h.batchStarted = true
}

// EndNotificationBatch closes the current batch: it synthesizes deletions for IRV resources that
// no longer exist, clears the per-request IRV state, and signals the flush loop (bypassing the
// granular limiter) if anything is pending.
func (h *handler) EndNotificationBatch() {
	h.lock.Lock()
	defer h.lock.Unlock()

	h.handleDeletionsFromIRV()
	h.batchStarted = false
	// Resetting the initial resource versions to nil, we need to make sure to handle IRV
	// for each incoming request independently
	h.initialResourceVersions = nil
	if len(h.entries) > 0 {
		h.immediateNotificationReceived.notify()
		h.notificationReceived.notify()
	}
}

// handleDeletionsFromIRV processes resources that are known to the client but are no longer present on the server.
// This indicates that the resource has been deleted and the client is unaware of this (e.g. in a re-connect scenario)
// The method update the entries to nil for such resources.
func (h *handler) handleDeletionsFromIRV() {
	for name, irv := range h.initialResourceVersions {
		if _, ok := h.entries[name]; !ok && !irv.received {
			slog.Debug("Resource no longer exists on the server but is still present on the client. "+
				"Explicitly marking the resource for deletion.", "resourceName", name)

			// in some corner case, when last resource is deleted. and there is no subscribed resource present in cache,
			// entries might be nil, so we need to allocate a new map.
			if h.entries == nil {
				h.entries = sendBufferPool.Get().(sendBuffer)
			}
			// NOTE(review): this inserts a zero-value SentResource (nil Resource) without checking
			// h.ignoreDeletes, unlike Notify — confirm IRV never applies to handlers created with
			// ignoreDeletes=true (e.g. pseudo-delta SotW), whose send callbacks dereference
			// Resource unconditionally.
			h.entries[name] = serverstats.SentResource{}
		}
	}
}

// handleMatchFromIRV checks if the given resource matches the initial resource versions (IRV).
// If the client already knows this exact version of the resource, it is skipped (not re-sent),
// its IRV entry is marked received, and true is returned.
func (h *handler) handleMatchFromIRV(name string, r *ads.RawResource) bool {
	if res, ok := h.initialResourceVersions[name]; ok {
		if r != nil && res.version == r.Version {
			slog.Debug(
				"Resource version matches with IRV from client, skipping this resource",
				"resourceName", name,
				"version", res.version,
			)
			res.received = true

			if h.statsHandler != nil {
				h.statsHandler.HandleServerEvent(h.ctx, &serverstats.IRVMatchedResource{
					ResourceName: name,
					Resource:     r,
				})
			}
			return true
		}
	}
	return false
}

// NewSotWHandler returns a BatchSubscriptionHandler that serves the state-of-the-world protocol
// for the given type URL, wrapping the raw rate limiters in the handlerLimiter adapter.
func NewSotWHandler(
	ctx context.Context,
	granularLimiter *rate.Limiter,
	globalLimiter *rate.Limiter,
	statsHandler serverstats.Handler,
	typeURL string,
	send func(res *ads.SotWDiscoveryResponse) error,
) BatchSubscriptionHandler {
	return newSotWHandler(
		ctx,
		(*rateLimiterWrapper)(granularLimiter),
		(*rateLimiterWrapper)(globalLimiter),
		statsHandler,
		typeURL,
		send,
	)
}

// newSotWHandler builds the SotW flush callback. For pseudo-delta SotW types the response only
// contains the entries from the current flush; otherwise the callback keeps a running
// state-of-the-world (allResources/versions) and re-sends it in full on every flush.
func newSotWHandler(
	ctx context.Context,
	granularLimiter handlerLimiter,
	globalLimiter handlerLimiter,
	statsHandler serverstats.Handler,
	typeURL string,
	send func(res *ads.SotWDiscoveryResponse) error,
) *handler {
	isPseudoDeltaSotW := utils.IsPseudoDeltaSotW(typeURL)
	var looper func(resources sendBuffer) error
	if isPseudoDeltaSotW {
		looper = func(entries sendBuffer) error {
			versions := map[string]string{}

			// NOTE(review): e.Resource is dereferenced unconditionally. Deletions are filtered by
			// Notify because this handler is created with ignoreDeletes=true, but
			// handleDeletionsFromIRV can still insert nil-Resource entries — confirm IRV cannot
			// occur on this path.
			for name, e := range entries {
				versions[name] = e.Resource.Version
			}

			res := &ads.SotWDiscoveryResponse{
				TypeUrl: typeURL,
				Nonce:   utils.NewNonce(0),
			}
			for _, e := range entries {
				res.Resources = append(res.Resources, e.Resource.Resource)
			}
			res.VersionInfo = utils.MapToProto(versions)
			return send(res)
		}
	} else {
		// Running state-of-the-world, captured by the closure and updated on every flush.
		allResources := sendBuffer{}
		versions := map[string]string{}

		looper = func(resources sendBuffer) error {
			for name, r := range resources {
				if r.Resource != nil {
					allResources[name] = r
					versions[name] = r.Resource.Version
				} else {
					delete(allResources, name)
					delete(versions, name)
				}
			}

			res := &ads.SotWDiscoveryResponse{
				TypeUrl: typeURL,
				Nonce:   utils.NewNonce(0),
			}
			for _, r := range allResources {
				res.Resources = append(res.Resources, r.Resource.Resource)
			}
			res.VersionInfo = utils.MapToProto(versions)
			return send(res)
		}
	}

	return newHandler(
		ctx,
		typeURL,
		granularLimiter,
		globalLimiter,
		statsHandler,
		isPseudoDeltaSotW,
		looper,
	)
}
-------------------------------------------------------------------------------- /internal/server/handlers_bench_test.go: --------------------------------------------------------------------------------
package internal

import (
	"fmt"
	"strconv"
	"sync"
	"testing"

	"github.com/linkedin/diderot/ads"
	"github.com/linkedin/diderot/testutils"
	"google.golang.org/protobuf/types/known/anypb"
)

// benchmarkHandlers pushes count notifications for each of `subscriptions` concurrent resource
// names through a handler with no-op limiters, and waits until the send callback has observed the
// final version of every resource.
func benchmarkHandlers(tb testing.TB, count, subscriptions int) {
	valueNames := make([]string, subscriptions)
	for i := range valueNames {
		valueNames[i] = strconv.Itoa(i)
	}

	ctx := testutils.Context(tb)

	var finished sync.WaitGroup
	finished.Add(subscriptions)
	const finalVersion = "done"
	h := newHandler(
		ctx,
		AnyTypeURL,
		NoopLimiter{},
		NoopLimiter{},
		new(customStatsHandler),
		false,
		func(resources sendBuffer) error {
			for _, r := range resources {
				if r.Resource.Version == finalVersion {
					finished.Done()
				}
			}
			return nil
		},
	)

	for _, name := range valueNames {
		go func(name string) {
			resource := testutils.MustMarshal(tb, ads.NewResource(name, "0", new(anypb.Any)))
			for i := 0; i < count-1; i++ {
				h.Notify(name, resource, ads.SubscriptionMetadata{})
			}
			// The last notification carries the sentinel version that releases the WaitGroup.
			h.Notify(
				name,
				&ads.RawResource{Name: name, Version: finalVersion, Resource: resource.Resource},
				ads.SubscriptionMetadata{},
			)
		}(name)
	}
	finished.Wait()
}

var increments = []int{1, 10, 100, 1000, 10_000}

func BenchmarkHandlers(b *testing.B) {
	for _, subscriptions := range increments {
		b.Run(fmt.Sprintf("%5d subs", subscriptions), func(b *testing.B) {
			benchmarkHandlers(b, b.N, subscriptions)
		})
	}
}

// TestHandlers runs the benchmark body once as a plain test to catch regressions without -bench.
func TestHandlers(t *testing.T) {
	benchmarkHandlers(t, 1000, 1000)
}
-------------------------------------------------------------------------------- /internal/server/handlers_delta.go: --------------------------------------------------------------------------------
package internal

import (
	"cmp"
	"context"
	"log/slog"
	"slices"
	"sync"

	"github.com/linkedin/diderot/ads"
	"github.com/linkedin/diderot/internal/utils"
	serverstats "github.com/linkedin/diderot/stats/server"
	"golang.org/x/time/rate"
	"google.golang.org/protobuf/proto"
)

// NewDeltaHandler returns a BatchSubscriptionHandler serving the delta protocol for the given
// type URL, chunking responses so no single message exceeds maxChunkSize (0 disables chunking).
func NewDeltaHandler(
	ctx context.Context,
	granularLimiter *rate.Limiter,
	globalLimiter *rate.Limiter,
	statsHandler serverstats.Handler,
	maxChunkSize int,
	typeUrl string,
	send func(res *ads.DeltaDiscoveryResponse) error,
) BatchSubscriptionHandler {
	return newDeltaHandler(
		ctx,
		(*rateLimiterWrapper)(granularLimiter),
		(*rateLimiterWrapper)(globalLimiter),
		statsHandler,
		maxChunkSize,
		typeUrl,
		send,
	)
}

// newDeltaHandler wires a deltaSender into newHandler: each flush is split into size-bounded
// chunks, and the global limiter is re-awaited between chunks.
func newDeltaHandler(
	ctx context.Context,
	granularLimiter handlerLimiter,
	globalLimiter handlerLimiter,
	statsHandler serverstats.Handler,
	maxChunkSize int,
	typeURL string,
	send func(res *ads.DeltaDiscoveryResponse) error,
) *handler {
	ds := &deltaSender{
		ctx:          ctx,
		typeURL:      typeURL,
		maxChunkSize: maxChunkSize,
		statsHandler: statsHandler,
		minChunkSize: initialChunkSize(typeURL),
	}

	return newHandler(
		ctx,
		typeURL,
		granularLimiter,
		globalLimiter,
		statsHandler,
		false,
		func(entries sendBuffer) error {
			for i, chunk := range ds.chunk(entries) {
				if i > 0 {
					// Respect the global limiter in between chunks
					err := waitForGlobalLimiter(ctx, globalLimiter, statsHandler)
					if err != nil {
						return err
					}
				}
				err := send(chunk)
				if err != nil {
					return err
				}
			}
			return nil
		},
	)
}

// queuedResourceUpdate pairs a resource name with its pre-computed encoded size (see
// encodedUpdateSize).
type queuedResourceUpdate struct {
	Name string
	Size int
}

type deltaSender struct {
	ctx          context.Context
	typeURL      string
	statsHandler serverstats.Handler
	// The maximum size (in bytes) that a chunk can be. This is determined by the client as anything
	// larger than this size will cause the message to be dropped.
	maxChunkSize int

	// The minimum size an encoded chunk will serialize to, in bytes. Used to check whether a given
	// update can _ever_ be sent, and as the initial size of a chunk. Note that this value only depends
	// on utils.NonceLength and the length of typeURL.
	minChunkSize int
}

var queuedUpdatesPool = sync.Pool{}

// newQueue creates a new []queuedResourceUpdate with at least enough capacity to hold the required
// size. Note that this returns a pointer to a slice instead of a slice to avoid heap allocations.
// This is the recommended way to use a [sync.Pool] with slices.
func newQueue(size int) *[]queuedResourceUpdate {
	// Attempt to get a queue from the pool. If it returns nil, ok will be false meaning the pool was
	// empty.
	queue, ok := queuedUpdatesPool.Get().(*[]queuedResourceUpdate)
	if ok && cap(*queue) >= size {
		*queue = (*queue)[:0]
	} else {
		if ok {
			// Return the queue that was too short to the pool
			queuedUpdatesPool.Put(queue)
		}
		queue = new([]queuedResourceUpdate)
		*queue = make([]queuedResourceUpdate, 0, size)
	}
	return queue
}

// chunk splits the pending updates into the fewest possible DeltaDiscoveryResponses such that
// each serializes to at most maxChunkSize bytes (unbounded when maxChunkSize is 0). Updates that
// can never fit are reported and dropped. Nonces encode the number of remaining chunks.
func (ds *deltaSender) chunk(resourceUpdates sendBuffer) (chunks []*ads.DeltaDiscoveryResponse) {
	queuePtr := newQueue(len(resourceUpdates))
	defer queuedUpdatesPool.Put(queuePtr)

	queue := *queuePtr
	for name, e := range resourceUpdates {
		queue = append(queue, queuedResourceUpdate{
			Name: name,
			Size: encodedUpdateSize(name, e.Resource),
		})
	}
	// Sort the updates in descending order
	slices.SortFunc(queue, func(a, b queuedResourceUpdate) int {
		return -cmp.Compare(a.Size, b.Size)
	})

	// This nested loop builds the fewest possible chunks it can from the given resourceUpdates map. It
	// implements an approximation of the bin-packing algorithm called next-fit-decreasing bin-packing
	// https://en.wikipedia.org/wiki/Next-fit-decreasing_bin_packing
	idx := 0
	for idx < len(queue) {
		// This chunk will hold all the updates for this loop iteration
		chunk := ds.newChunk()
		chunkSize := proto.Size(chunk)

		for ; idx < len(queue); idx++ {
			update := queue[idx]
			r := resourceUpdates[update.Name]

			if ds.maxChunkSize > 0 {
				if ds.minChunkSize+update.Size > ds.maxChunkSize {
					// This condition only occurs if the update can never be sent, i.e. it is too large and will
					// always be dropped by the client. It should therefore be skipped altogether, but flagged
					// accordingly.
					if ds.statsHandler != nil {
						ds.statsHandler.HandleServerEvent(ds.ctx, &serverstats.ResourceOverMaxSize{
							Resource:        r.Resource,
							ResourceSize:    update.Size,
							MaxResourceSize: ds.maxChunkSize,
						})
					}
					slog.ErrorContext(
						ds.ctx,
						"Cannot send resource update because it is larger than configured max delta response size",
						"maxDeltaResponseSize", ds.maxChunkSize,
						"name", update.Name,
						"updateSize", update.Size,
						"resource", r,
					)
					continue
				}
				if chunkSize+update.Size > ds.maxChunkSize {
					// This update it too large to be sent along with the current chunk, skip it for now and
					// attempt it in the next chunk.
					break
				}
			}

			if r.Resource != nil {
				chunk.Resources = append(chunk.Resources, r.Resource)
			} else {
				chunk.RemovedResources = append(chunk.RemovedResources, update.Name)
			}
			// Add the resource since it is small enough to be added to the chunk
			chunkSize += update.Size
		}

		chunks = append(chunks, chunk)
	}

	if len(chunks) > 1 {
		slog.WarnContext(
			ds.ctx,
			"Response exceeded max response size, sent in chunks",
			"chunks", len(chunks),
			"typeURL", ds.typeURL,
			"updates", len(queue),
		)
		for i, c := range chunks {
			c.Nonce = utils.NewNonce(len(chunks) - i - 1)
		}
	} else {
		// NOTE(review): if resourceUpdates is ever empty, the outer loop never runs and chunks[0]
		// panics — confirm callers guarantee a non-empty send buffer.
		chunks[0].Nonce = utils.NewNonce(0)
	}

	return chunks
}

// newChunk returns an empty response pre-populated with the sender's type URL.
func (ds *deltaSender) newChunk() *ads.DeltaDiscoveryResponse {
	return &ads.DeltaDiscoveryResponse{
		TypeUrl: ds.typeURL,
	}
}

// protobufSliceOverhead is the per-field wire overhead (tag + length prefix) assumed for the
// repeated fields of a DeltaDiscoveryResponse.
const protobufSliceOverhead = 2

// initialChunkSize returns the serialized size of an otherwise-empty response carrying the type
// URL and a nonce.
func initialChunkSize(typeUrl string) int {
	return protobufSliceOverhead + len(typeUrl) + protobufSliceOverhead + utils.NonceLength
}

// encodedUpdateSize returns the amount of bytes it takes to encode the given update in an
// *ads.DeltaDiscoveryResponse.
func encodedUpdateSize(name string, r *ads.RawResource) int {
	resourceSize := protobufSliceOverhead
	if r != nil {
		resourceSize += proto.Size(r)
	} else {
		// Deletions are encoded as just the name in RemovedResources.
		resourceSize += len(name)
	}
	return resourceSize
}
-------------------------------------------------------------------------------- /internal/server/handlers_delta_test.go: --------------------------------------------------------------------------------
package internal

import (
	"context"
	"slices"
	"strings"
	"sync/atomic"
	"testing"

	"github.com/linkedin/diderot/ads"
	"github.com/linkedin/diderot/internal/utils"
	serverstats "github.com/linkedin/diderot/stats/server"
	"github.com/linkedin/diderot/testutils"
	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/wrapperspb"
)

// TestDeltaHandler checks that one update and one deletion flush into a single response with the
// expected TypeUrl, Resources and RemovedResources.
func TestDeltaHandler(t *testing.T) {
	l := NewTestHandlerLimiter()

	typeURL := utils.GetTypeURL[*wrapperspb.BoolValue]()
	var lastRes *ads.DeltaDiscoveryResponse
	h := newDeltaHandler(
		testutils.Context(t),
		NoopLimiter{},
		l,
		new(customStatsHandler),
		0,
		typeURL,
		func(res *ads.DeltaDiscoveryResponse) error {
			defer l.Done()
			lastRes = res
			return nil
		},
	)

	const foo, bar = "foo", "bar"
	h.Notify(foo, nil, ignoredMetadata)
	r := new(ads.RawResource)
	h.Notify(bar, r, ignoredMetadata)

	l.Release()

	require.Equal(t, typeURL, lastRes.TypeUrl)
	require.Len(t, lastRes.Resources, 1)
	require.Equal(t, r, lastRes.Resources[0])
	require.Equal(t, []string{foo}, lastRes.RemovedResources)
}

// TestEncodedUpdateSize verifies encodedUpdateSize against the actual serialized growth of a
// response for additions, removals and a mix of both.
func TestEncodedUpdateSize(t *testing.T) {
	foo := testutils.MustMarshal(t, ads.NewResource("foo", "42", new(wrapperspb.Int64Value)))
	notFoo := testutils.MustMarshal(t, ads.NewResource("notFoo", "27", new(wrapperspb.Int64Value)))

	checkSize := func(t *testing.T, msg proto.Message, size int) {
		require.Equal(t, size, proto.Size(msg))
		data, err := proto.Marshal(msg)
		require.NoError(t, err)
		require.Len(t, data, size)
	}

	ds := &deltaSender{typeURL: utils.GetTypeURL[*wrapperspb.StringValue]()}

	t.Run("add", func(t *testing.T) {
		res := ds.newChunk()
		responseSize := proto.Size(res)

		res.Resources = append(res.Resources, foo)
		responseSize += encodedUpdateSize(foo.Name, foo)
		checkSize(t, res, responseSize)

		res.Resources = append(res.Resources, notFoo)
		responseSize += encodedUpdateSize(notFoo.Name, notFoo)
		checkSize(t, res, responseSize)
	})
	t.Run("remove", func(t *testing.T) {
		res := ds.newChunk()
		responseSize := proto.Size(res)

		res.RemovedResources = append(res.RemovedResources, foo.Name)
		responseSize += encodedUpdateSize(foo.Name, nil)
		checkSize(t, res, responseSize)

		res.RemovedResources = append(res.RemovedResources, notFoo.Name)
		responseSize += encodedUpdateSize(notFoo.Name, nil)
		checkSize(t, res, responseSize)
	})
	t.Run("add and remove", func(t *testing.T) {
		res := ds.newChunk()
		responseSize := proto.Size(res)

		res.Resources = append(res.Resources, foo)
		responseSize += encodedUpdateSize(foo.Name, foo)
		checkSize(t, res, responseSize)

		res.RemovedResources = append(res.RemovedResources, notFoo.Name)
		responseSize += encodedUpdateSize(notFoo.Name, nil)
		checkSize(t, res, responseSize)
	})
}

// TestInitialChunkSize verifies initialChunkSize matches the serialized size of an empty response
// carrying only the type URL and a nonce.
func TestInitialChunkSize(t *testing.T) {
	typeURL := utils.GetTypeURL[*wrapperspb.StringValue]()
	require.Equal(t, proto.Size(&ads.DeltaDiscoveryResponse{
		TypeUrl: typeURL,
		Nonce:   utils.NewNonce(0),
	}), initialChunkSize(typeURL))
}

func TestDeltaHandlerChunking(t *testing.T) {
	foo := 
testutils.MustMarshal(t, ads.NewResource("foo", "0", wrapperspb.String("foo"))) 112 | bar := testutils.MustMarshal(t, ads.NewResource("bar", "0", wrapperspb.String("bar"))) 113 | require.Equal(t, proto.Size(foo), proto.Size(bar)) 114 | resourceSize := proto.Size(foo) 115 | 116 | typeURL := utils.GetTypeURL[*wrapperspb.StringValue]() 117 | statsHandler := new(customStatsHandler) 118 | ds := &deltaSender{ 119 | typeURL: typeURL, 120 | statsHandler: statsHandler, 121 | maxChunkSize: initialChunkSize(typeURL) + protobufSliceOverhead + resourceSize, 122 | minChunkSize: initialChunkSize(typeURL), 123 | } 124 | 125 | getSentResponses := func(resources sendBuffer, expectedChunks int) []*ads.DeltaDiscoveryResponse { 126 | responses := ds.chunk(resources) 127 | require.Len(t, responses, expectedChunks) 128 | expectedRemainingChunks := 0 129 | for _, res := range slices.Backward(responses) { 130 | remaining, err := ads.ParseRemainingChunksFromNonce(res.Nonce) 131 | require.NoError(t, err) 132 | require.Equal(t, expectedRemainingChunks, remaining) 133 | expectedRemainingChunks++ 134 | } 135 | return responses 136 | } 137 | 138 | sentResponses := getSentResponses(sendBuffer{ 139 | foo.Name: serverstats.SentResource{Resource: foo}, 140 | bar.Name: serverstats.SentResource{Resource: bar}, 141 | }, 2) 142 | require.Equal(t, len(sentResponses[0].Resources), 1) 143 | require.Equal(t, len(sentResponses[1].Resources), 1) 144 | response0 := sentResponses[0].Resources[0] 145 | response1 := sentResponses[1].Resources[0] 146 | 147 | if response0.Name == foo.Name { 148 | testutils.ProtoEquals(t, foo, response0) 149 | testutils.ProtoEquals(t, bar, response1) 150 | } else { 151 | testutils.ProtoEquals(t, bar, response0) 152 | testutils.ProtoEquals(t, foo, response1) 153 | } 154 | 155 | // Delete resources whose names are the same size as the resources to trip the chunker with the same conditions 156 | name1 := strings.Repeat("1", resourceSize) 157 | name2 := strings.Repeat("2", resourceSize) 
158 | sentResponses = getSentResponses(sendBuffer{ 159 | name1: serverstats.SentResource{Resource: nil}, 160 | name2: serverstats.SentResource{Resource: nil}, 161 | }, 2) 162 | require.Equal(t, len(sentResponses[0].RemovedResources), 1) 163 | require.Equal(t, len(sentResponses[1].RemovedResources), 1) 164 | require.ElementsMatch(t, 165 | []string{name1, name2}, 166 | []string{sentResponses[0].RemovedResources[0], sentResponses[1].RemovedResources[0]}, 167 | ) 168 | 169 | small1, small2, small3 := "a", "b", "c" 170 | wayTooBig := strings.Repeat("3", 10*resourceSize) 171 | 172 | sentResponses = getSentResponses(sendBuffer{ 173 | small1: serverstats.SentResource{Resource: nil}, 174 | small2: serverstats.SentResource{Resource: nil}, 175 | small3: serverstats.SentResource{Resource: nil}, 176 | wayTooBig: serverstats.SentResource{Resource: nil}, 177 | }, 1) 178 | require.Equal(t, len(sentResponses[0].RemovedResources), 3) 179 | require.ElementsMatch(t, []string{small1, small2, small3}, sentResponses[0].RemovedResources) 180 | require.Equal(t, int64(1), statsHandler.DeltaResourcesOverMaxSize.Load()) 181 | } 182 | 183 | type customStatsHandler struct { 184 | DeltaResourcesOverMaxSize atomic.Int64 `metric:",counter"` 185 | } 186 | 187 | func (h *customStatsHandler) HandleServerEvent(ctx context.Context, event serverstats.Event) { 188 | if _, ok := event.(*serverstats.ResourceOverMaxSize); ok { 189 | h.DeltaResourcesOverMaxSize.Add(1) 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /internal/server/handlers_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "maps" 5 | "strconv" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | 11 | gocmp "github.com/google/go-cmp/cmp" 12 | gocmpopts "github.com/google/go-cmp/cmp/cmpopts" 13 | "github.com/linkedin/diderot/ads" 14 | "github.com/linkedin/diderot/internal/utils" 15 | serverstats 
"github.com/linkedin/diderot/stats/server" 16 | "github.com/linkedin/diderot/testutils" 17 | "github.com/stretchr/testify/require" 18 | "google.golang.org/protobuf/testing/protocmp" 19 | "google.golang.org/protobuf/types/known/anypb" 20 | "google.golang.org/protobuf/types/known/wrapperspb" 21 | ) 22 | 23 | const AnyTypeURL = "type.googleapis.com/google.protobuf.Any" 24 | 25 | func checkSendBuffer(t *testing.T, expected, actual sendBuffer) { 26 | opts := []gocmp.Option{gocmpopts.IgnoreFields(serverstats.SentResource{}, "QueuedAt"), protocmp.Transform()} 27 | if !gocmp.Equal(expected, actual, opts...) { 28 | require.Equal(t, expected, actual) 29 | } 30 | } 31 | 32 | // TestHandlerDebounce checks the following: 33 | // 1. That the handler does not invoke send as long as the debouncer has not allowed it to. 34 | // 2. That updates that come in while send is being invoked do not get missed. 35 | // 3. That if multiple updates for the same resource come in, only the latest one is respected. 36 | func TestHandlerDebounce(t *testing.T) { 37 | var released atomic.Bool 38 | l := NewTestHandlerLimiter() 39 | 40 | var enterSendWg, continueSendWg sync.WaitGroup 41 | continueSendWg.Add(1) 42 | 43 | actualResources := make(sendBuffer) 44 | 45 | h := newHandler( 46 | testutils.Context(t), 47 | AnyTypeURL, 48 | NoopLimiter{}, 49 | l, 50 | new(customStatsHandler), 51 | false, 52 | func(resources sendBuffer) error { 53 | require.True(t, released.Swap(false), "send invoked without being released") 54 | require.NotEmpty(t, resources) 55 | enterSendWg.Done() 56 | continueSendWg.Wait() 57 | defer l.Done() 58 | for k, e := range resources { 59 | actualResources[k] = e 60 | } 61 | return nil 62 | }, 63 | ) 64 | 65 | // declare the various times upfront and ensure they are all unique, which will allow validating the interactions 66 | // with the handler 67 | var ( 68 | fooSubscribedAt = time.Now() 69 | fooCreateMetadata = ads.SubscriptionMetadata{ 70 | SubscribedAt: fooSubscribedAt, 71 | 
ModifiedAt: fooSubscribedAt.Add(2 * time.Hour), 72 | CachedAt: fooSubscribedAt.Add(3 * time.Hour), 73 | } 74 | fooDeleteMetadata = ads.SubscriptionMetadata{ 75 | SubscribedAt: fooSubscribedAt, 76 | ModifiedAt: time.Time{}, 77 | CachedAt: fooSubscribedAt.Add(4 * time.Hour), 78 | } 79 | 80 | barCreateMetadata = ads.SubscriptionMetadata{ 81 | SubscribedAt: fooSubscribedAt.Add(5 * time.Hour), 82 | ModifiedAt: fooSubscribedAt.Add(6 * time.Hour), 83 | CachedAt: fooSubscribedAt.Add(7 * time.Hour), 84 | } 85 | ) 86 | 87 | const foo, bar = "foo", "bar" 88 | barR := new(ads.RawResource) 89 | 90 | h.Notify(foo, new(ads.RawResource), fooCreateMetadata) 91 | h.Notify(foo, nil, fooDeleteMetadata) 92 | 93 | enterSendWg.Add(1) 94 | go func() { 95 | enterSendWg.Wait() 96 | h.Notify(bar, barR, barCreateMetadata) 97 | continueSendWg.Done() 98 | }() 99 | 100 | released.Store(true) 101 | l.Release() 102 | checkSendBuffer(t, sendBuffer{ 103 | foo: serverstats.SentResource{ 104 | Resource: nil, 105 | Metadata: fooDeleteMetadata, 106 | }, 107 | }, actualResources) 108 | delete(actualResources, foo) 109 | 110 | enterSendWg.Add(1) 111 | released.Store(true) 112 | l.Release() 113 | checkSendBuffer(t, sendBuffer{ 114 | bar: serverstats.SentResource{ 115 | Resource: barR, 116 | Metadata: barCreateMetadata, 117 | }, 118 | }, actualResources) 119 | } 120 | 121 | func TestHandlerBatching(t *testing.T) { 122 | var released atomic.Bool 123 | ch := make(chan sendBuffer) 124 | granular := NewTestHandlerLimiter() 125 | h := newHandler( 126 | testutils.Context(t), 127 | AnyTypeURL, 128 | granular, 129 | NoopLimiter{}, 130 | new(customStatsHandler), 131 | false, 132 | func(resources sendBuffer) error { 133 | // Double check that send isn't invoked before it's expected 134 | if !released.Load() { 135 | t.Fatalf("send invoked before release!") 136 | } 137 | ch <- maps.Clone(resources) 138 | return nil 139 | }, 140 | ) 141 | expectedEntries := make(sendBuffer) 142 | notify := func() { 143 | name := 
strconv.Itoa(len(expectedEntries)) 144 | h.Notify(name, nil, ads.SubscriptionMetadata{}) 145 | expectedEntries[name] = serverstats.SentResource{Resource: nil} 146 | } 147 | 148 | h.StartNotificationBatch(nil, 0) 149 | notify() 150 | 151 | for i := 0; i < 100; i++ { 152 | notify() 153 | } 154 | released.Store(true) 155 | h.EndNotificationBatch() 156 | 157 | checkSendBuffer(t, expectedEntries, <-ch) 158 | 159 | released.Store(false) 160 | 161 | clear(expectedEntries) 162 | notify() 163 | granular.WaitForReserve() 164 | 165 | released.Store(true) 166 | // Check that EndNotificationBatch skips the granular limiter 167 | h.EndNotificationBatch() 168 | 169 | checkSendBuffer(t, expectedEntries, <-ch) 170 | } 171 | 172 | func TestHandlerDoesNothingOnEmptyBatch(t *testing.T) { 173 | h := newHandler( 174 | testutils.Context(t), 175 | AnyTypeURL, 176 | // Make both limiters nil, if the handler interacts with them at all the test should fail 177 | nil, 178 | nil, 179 | new(customStatsHandler), 180 | false, 181 | func(_ sendBuffer) error { 182 | require.Fail(t, "notify called") 183 | return nil 184 | }, 185 | ) 186 | h.StartNotificationBatch(nil, 0) 187 | h.EndNotificationBatch() 188 | } 189 | 190 | var ignoredMetadata = ads.SubscriptionMetadata{} 191 | 192 | func TestPseudoDeltaSotWHandler(t *testing.T) { 193 | typeUrl := utils.GetTypeURL[*wrapperspb.BoolValue]() 194 | // This test relies on Bool being a pseudo delta resource type, so fail the test early otherwise 195 | require.True(t, utils.IsPseudoDeltaSotW(typeUrl)) 196 | 197 | l := NewTestHandlerLimiter() 198 | var lastRes *ads.SotWDiscoveryResponse 199 | h := newSotWHandler( 200 | testutils.Context(t), 201 | NoopLimiter{}, 202 | l, 203 | new(customStatsHandler), 204 | typeUrl, 205 | func(res *ads.SotWDiscoveryResponse) error { 206 | defer l.Done() 207 | lastRes = res 208 | return nil 209 | }, 210 | ) 211 | 212 | const foo, bar, baz = "foo", "bar", "baz" 213 | fooR := ads.NewResource(foo, "0", wrapperspb.Bool(true)) 214 | 
barR := ads.NewResource(bar, "0", wrapperspb.Bool(false)) 215 | bazR := ads.NewResource(baz, "0", wrapperspb.Bool(false)) 216 | h.Notify(foo, testutils.MustMarshal(t, fooR), ignoredMetadata) 217 | 218 | l.Release() 219 | require.Equal(t, typeUrl, lastRes.TypeUrl) 220 | require.ElementsMatch(t, []*anypb.Any{testutils.MustMarshal(t, fooR).Resource}, lastRes.Resources) 221 | 222 | const wait = 500 * time.Millisecond 223 | // PseudoDeltaSotW doesn't have a notion of deletions. A deleted resource simply never shows up again unless 224 | // it's recreated. The next call to Release should therefore block until the handler invokes l.reserve(), which it 225 | // should _not_ do until a resource is created. This test checks that that's the case by deleting foo then waiting 226 | // creating bar 500ms before creating bar, then checking how long Release blocked, which should be roughly 500ms. 227 | h.Notify(foo, nil, ignoredMetadata) 228 | go func() { 229 | time.Sleep(wait) 230 | h.Notify(bar, testutils.MustMarshal(t, barR), ignoredMetadata) 231 | }() 232 | 233 | start := time.Now() 234 | l.Release() 235 | require.WithinDuration(t, time.Now(), start.Add(wait), 10*time.Millisecond) 236 | require.ElementsMatch(t, []*anypb.Any{testutils.MustMarshal(t, bazR).Resource}, lastRes.Resources) 237 | } 238 | 239 | func TestHandlerBatchingWithIRV(t *testing.T) { 240 | const ( 241 | foo = "foo" 242 | bar = "bar" 243 | ) 244 | var released atomic.Bool 245 | ch := make(chan sendBuffer) 246 | handler := newHandler( 247 | testutils.Context(t), 248 | AnyTypeURL, 249 | NoopLimiter{}, 250 | NoopLimiter{}, 251 | new(customStatsHandler), 252 | false, 253 | func(resources sendBuffer) error { 254 | ch <- maps.Clone(resources) 255 | return nil 256 | }, 257 | ) 258 | notify := func(name string, resource *ads.RawResource) { 259 | handler.Notify(name, resource, ads.SubscriptionMetadata{}) 260 | } 261 | 262 | t.Run("partial update, foo not updated and bar updated", func(t *testing.T) { 263 | req := 
newDeltaReq([]string{"foo", "bar"}, map[string]string{"foo": "0", "bar": "0"}) 264 | handler.StartNotificationBatch(req.InitialResourceVersions, 0) 265 | fooResource := newRawResource(foo, "0") 266 | barResource := newRawResource(bar, "1") 267 | notify(foo, fooResource) 268 | notify(bar, barResource) 269 | released.Store(true) 270 | handler.EndNotificationBatch() 271 | checkSendBuffer(t, sendBuffer{ 272 | barResource.Name: serverstats.SentResource{Resource: barResource}, 273 | }, <-ch) 274 | }) 275 | 276 | t.Run("partial update, foo deleted and bar updated", func(t *testing.T) { 277 | req := newDeltaReq([]string{foo, bar}, map[string]string{foo: "0", bar: "0"}) 278 | handler.StartNotificationBatch(req.InitialResourceVersions, 0) 279 | barResource := newRawResource(bar, "1") 280 | notify(bar, barResource) 281 | released.Store(true) 282 | handler.EndNotificationBatch() 283 | checkSendBuffer(t, sendBuffer{ 284 | barResource.Name: serverstats.SentResource{Resource: barResource}, 285 | foo: serverstats.SentResource{Resource: nil}, 286 | }, <-ch) 287 | }) 288 | 289 | t.Run("partial update, foo deleted and bar updated with wildcard subscription", func(t *testing.T) { 290 | req := newDeltaReq([]string{ads.WildcardSubscription}, map[string]string{foo: "0", bar: "0"}) 291 | handler.StartNotificationBatch(req.InitialResourceVersions, 0) 292 | barResource := newRawResource(bar, "1") 293 | notify(bar, barResource) 294 | released.Store(true) 295 | handler.EndNotificationBatch() 296 | checkSendBuffer(t, sendBuffer{ 297 | barResource.Name: serverstats.SentResource{Resource: barResource}, 298 | foo: serverstats.SentResource{Resource: nil}, 299 | }, <-ch) 300 | }) 301 | } 302 | 303 | func newDeltaReq(subscribe []string, versions map[string]string) *ads.DeltaDiscoveryRequest { 304 | return &ads.DeltaDiscoveryRequest{ 305 | ResourceNamesSubscribe: subscribe, 306 | InitialResourceVersions: versions, 307 | } 308 | } 309 | 310 | func newRawResource(name string, version string) 
*ads.RawResource { 311 | return &ads.RawResource{ 312 | Name: name, 313 | Version: version, 314 | Resource: &anypb.Any{}, 315 | } 316 | 317 | } 318 | -------------------------------------------------------------------------------- /internal/server/limiter.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "time" 5 | 6 | "golang.org/x/time/rate" 7 | ) 8 | 9 | // handlerLimiter is an interface used by the handler implementation. It exists for the sole purpose 10 | // of testing and is trivially implemented by rate.Limiter using rateLimiterWrapper. It is not 11 | // exposed in this package's public API. 12 | type handlerLimiter interface { 13 | // reserve returns a channel that will receive the current time (or be closed) once the rate limit 14 | // clears. Callers should wait until this occurs before acting. The returned cancel function should 15 | // be invoked if the caller did not wait for the rate limit to clear, though it is safe to call even 16 | // if after the rate limit cleared. In other words, it is safe to invoke in a deferred expression. 17 | reserve() (reservation <-chan time.Time, cancel func()) 18 | } 19 | 20 | var _ handlerLimiter = (*rateLimiterWrapper)(nil) 21 | 22 | // rateLimiterWrapper implements handlerLimiter using a rate.Limiter 23 | type rateLimiterWrapper rate.Limiter 24 | 25 | func (w *rateLimiterWrapper) reserve() (reservation <-chan time.Time, cancel func()) { 26 | r := (*rate.Limiter)(w).Reserve() 27 | timer := time.NewTimer(r.Delay()) 28 | return timer.C, func() { 29 | // Stopping the timer cleans up any goroutines or schedules associated with this timer. Invoking this 30 | // after the timer fires is a noop. 
timer.Stop()
	}
}
--------------------------------------------------------------------------------
/internal/server/limiter_test.go:
--------------------------------------------------------------------------------
package internal

import (
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"golang.org/x/time/rate"
)

// Compile-time check that TestRateLimiter satisfies handlerLimiter.
var _ handlerLimiter = (*TestRateLimiter)(nil)

// NewTestHandlerLimiter returns a TestRateLimiter ready for use.
func NewTestHandlerLimiter() *TestRateLimiter {
	return &TestRateLimiter{
		cond: sync.NewCond(new(sync.Mutex)),
	}
}

// TestRateLimiter is a handlerLimiter for tests: reserve blocks until the test explicitly invokes
// Release, giving the test full control over when the handler under test may proceed.
type TestRateLimiter struct {
	cond *sync.Cond
	// ch is the channel handed out by reserve; nil while no reservation is outstanding.
	ch chan time.Time
	wg sync.WaitGroup
}

// reserve implements handlerLimiter. It lazily creates the shared reservation channel and signals
// any goroutine blocked in Release or WaitForReserve that a reservation now exists.
func (l *TestRateLimiter) reserve() (<-chan time.Time, func()) {
	l.cond.L.Lock()
	defer l.cond.L.Unlock()

	if l.ch == nil {
		l.ch = make(chan time.Time)
		l.cond.Signal()
	}
	return l.ch, func() {}
}

// Release waits for an outstanding reservation (if none exists yet), unblocks the goroutine
// receiving on the reservation channel, then waits until that goroutine calls Done.
func (l *TestRateLimiter) Release() {
	l.cond.L.Lock()
	if l.ch == nil {
		l.cond.Wait()
	}
	// Clear ch under the lock so the next reserve creates a fresh reservation.
	ch := l.ch
	l.ch = nil
	l.cond.L.Unlock()

	l.wg.Add(1)
	ch <- time.Now()
	l.wg.Wait()
}

// WaitForReserve waits for another goroutine to call reserve (if one hasn't already)
func (l *TestRateLimiter) WaitForReserve() {
	l.cond.L.Lock()
	defer l.cond.L.Unlock()

	if l.ch == nil {
		l.cond.Wait()
	}
}

// Done signals that the work triggered by Release has completed, unblocking Release.
func (l *TestRateLimiter) Done() {
	l.wg.Done()
}

// TestHandlerLimiter checks that rateLimiterWrapper honors the underlying rate.Limiter: with a
// burst of 1 at 10/s, the first reservation clears immediately and the second ~100ms later.
func TestHandlerLimiter(t *testing.T) {
	l := (*rateLimiterWrapper)(rate.NewLimiter(10, 1))
	start := time.Now()
	ch1, _ := l.reserve()
	ch2, _ := l.reserve()

	const delta = float64(5 * time.Millisecond)
	require.InDelta(t, 0, (<-ch1).Sub(start), delta)
	require.InDelta(t, 100*time.Millisecond, (<-ch2).Sub(start), delta)
}

// closedReservation is a pre-closed channel: receiving from it never blocks, which makes
// NoopLimiter's reservations clear instantly.
var closedReservation = func() chan time.Time {
	ch := make(chan time.Time)
	close(ch)
return ch 80 | }() 81 | 82 | type NoopLimiter struct{} 83 | 84 | func (n NoopLimiter) reserve() (reservation <-chan time.Time, cancel func()) { 85 | return closedReservation, func() {} 86 | } 87 | -------------------------------------------------------------------------------- /internal/server/subscription_manager.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "context" 5 | "slices" 6 | "sync" 7 | 8 | "github.com/linkedin/diderot/ads" 9 | "github.com/linkedin/diderot/internal/utils" 10 | "google.golang.org/protobuf/proto" 11 | ) 12 | 13 | // ResourceLocator is a copy of the interface in the root package, to avoid import cycles. 14 | type ResourceLocator interface { 15 | Subscribe( 16 | streamCtx context.Context, 17 | typeURL, resourceName string, 18 | handler ads.RawSubscriptionHandler, 19 | ) (unsubscribe func()) 20 | } 21 | 22 | type SubscriptionManager[REQ proto.Message] interface { 23 | // ProcessSubscriptions handles subscribing/unsubscribing from the resources provided in the given 24 | // xDS request. This function will always invoke BatchSubscriptionHandler.StartNotificationBatch 25 | // before it starts processing the subscriptions and always complete with 26 | // BatchSubscriptionHandler.EndNotificationBatch. Since the cache implementation always notifies the 27 | // SubscriptionHandler with the current value of the subscribed resource, 28 | // BatchSubscriptionHandler.EndNotificationBatch will be invoked after the handler has been notified 29 | // of all the resources requested. 30 | ProcessSubscriptions(REQ) 31 | // IsSubscribedTo checks whether the client has subscribed to the given resource name. 32 | IsSubscribedTo(name string) bool 33 | // UnsubscribeAll cleans up any active subscriptions and disables the wildcard subscription if enabled. 
34 | UnsubscribeAll() 35 | } 36 | 37 | // subscriptionManagerCore keeps track of incoming subscription and unsubscription requests, and 38 | // executes the corresponding actions against the underlying cache. It is meant to be embedded in 39 | // deltaSubscriptionManager and sotWSubscriptionManager to deduplicate the subscription tracking 40 | // logic. 41 | type subscriptionManagerCore struct { 42 | ctx context.Context 43 | locator ResourceLocator 44 | typeURL string 45 | handler BatchSubscriptionHandler 46 | sizeEstimator SendBufferSizeEstimator 47 | 48 | lock sync.Mutex 49 | subscriptions map[string]func() 50 | } 51 | 52 | func newSubscriptionManagerCore( 53 | ctx context.Context, 54 | locator ResourceLocator, 55 | typeURL string, 56 | handler BatchSubscriptionHandler, 57 | sizeEstimator SendBufferSizeEstimator, 58 | ) *subscriptionManagerCore { 59 | c := &subscriptionManagerCore{ 60 | ctx: ctx, 61 | locator: locator, 62 | typeURL: typeURL, 63 | handler: handler, 64 | sizeEstimator: sizeEstimator, 65 | subscriptions: make(map[string]func()), 66 | } 67 | // Ensure all the subscriptions managed by this subscription manager are cleaned up, otherwise they 68 | // will dangle forever in the cache and prevent the backing SubscriptionHandler from being collected 69 | // as well. 70 | context.AfterFunc(ctx, func() { 71 | c.UnsubscribeAll() 72 | }) 73 | return c 74 | } 75 | 76 | type deltaSubscriptionManager struct { 77 | *subscriptionManagerCore 78 | firstCallReceived bool 79 | } 80 | 81 | // NewDeltaSubscriptionManager creates a new SubscriptionManager specifically designed to handle the 82 | // Delta xDS protocol's subscription semantics. 
83 | func NewDeltaSubscriptionManager( 84 | ctx context.Context, 85 | locator ResourceLocator, 86 | typeURL string, 87 | handler BatchSubscriptionHandler, 88 | sizeEstimator SendBufferSizeEstimator, 89 | ) SubscriptionManager[*ads.DeltaDiscoveryRequest] { 90 | return &deltaSubscriptionManager{ 91 | subscriptionManagerCore: newSubscriptionManagerCore(ctx, locator, typeURL, handler, sizeEstimator), 92 | } 93 | } 94 | 95 | type sotWSubscriptionManager struct { 96 | *subscriptionManagerCore 97 | receivedExplicitSubscriptions bool 98 | } 99 | 100 | // NewSotWSubscriptionManager creates a new SubscriptionManager specifically designed to handle the 101 | // State-of-the-World xDS protocol's subscription semantics. 102 | func NewSotWSubscriptionManager( 103 | ctx context.Context, 104 | locator ResourceLocator, 105 | typeURL string, 106 | handler BatchSubscriptionHandler, 107 | sizeEstimator SendBufferSizeEstimator, 108 | ) SubscriptionManager[*ads.SotWDiscoveryRequest] { 109 | return &sotWSubscriptionManager{ 110 | subscriptionManagerCore: newSubscriptionManagerCore(ctx, locator, typeURL, handler, sizeEstimator), 111 | } 112 | } 113 | 114 | // ProcessSubscriptions processes the subscriptions for a delta stream. It manages the implicit 115 | // wildcard subscription outlined in [the spec]. The server should default to the wildcard 116 | // subscription if the client's first request does not provide any resource names to explicitly 117 | // subscribe to. The client must then explicit unsubscribe from the wildcard. Subsequent requests 118 | // that do not provide any explicit resource names will not alter the current subscription state. 
119 | // 120 | // [the spec]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol.html#how-the-client-specifies-what-resources-to-return 121 | func (m *deltaSubscriptionManager) ProcessSubscriptions(req *ads.DeltaDiscoveryRequest) { 122 | subscribe, estimatedSize := m.cleanSubscriptionsAndEstimateSize( 123 | req.ResourceNamesSubscribe, req.InitialResourceVersions, 124 | ) 125 | 126 | m.handler.StartNotificationBatch(req.InitialResourceVersions, estimatedSize) 127 | defer m.handler.EndNotificationBatch() 128 | 129 | m.lock.Lock() 130 | defer m.lock.Unlock() 131 | 132 | if !m.firstCallReceived { 133 | m.firstCallReceived = true 134 | if len(subscribe) == 0 { 135 | subscribe = []string{ads.WildcardSubscription} 136 | } 137 | } 138 | 139 | for _, name := range subscribe { 140 | m.subscribe(name) 141 | } 142 | 143 | for _, name := range req.ResourceNamesUnsubscribe { 144 | m.unsubscribe(name) 145 | } 146 | } 147 | 148 | // ProcessSubscriptions processes the subscriptions for a state of the world stream. It manages the 149 | // implicit wildcard subscription outlined in [the spec]. The server should default to the wildcard 150 | // subscription if the client has not sent any resource names to explicitly subscribe to. After the 151 | // first request that provides explicit resource names, the implicit wildcard subscription should 152 | // disappear. 153 | // 154 | // [the spec]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol.html#how-the-client-specifies-what-resources-to-return 155 | func (m *sotWSubscriptionManager) ProcessSubscriptions(req *ads.SotWDiscoveryRequest) { 156 | subscribe, estimatedSize := m.cleanSubscriptionsAndEstimateSize(req.ResourceNames, nil) 157 | 158 | // sotWSubscriptionManager does not support initial resource versions, so we pass nil. 
159 | m.handler.StartNotificationBatch(nil, estimatedSize) 160 | defer m.handler.EndNotificationBatch() 161 | 162 | m.lock.Lock() 163 | defer m.lock.Unlock() 164 | 165 | m.receivedExplicitSubscriptions = m.receivedExplicitSubscriptions || len(subscribe) != 0 166 | if !m.receivedExplicitSubscriptions { 167 | subscribe = []string{ads.WildcardSubscription} 168 | } 169 | 170 | intersection := utils.Set[string]{} 171 | for _, name := range subscribe { 172 | if _, ok := m.subscriptions[name]; ok { 173 | intersection.Add(name) 174 | } 175 | } 176 | 177 | for name := range m.subscriptions { 178 | if !intersection.Contains(name) { 179 | m.unsubscribe(name) 180 | } 181 | } 182 | 183 | for _, name := range subscribe { 184 | if !intersection.Contains(name) { 185 | m.subscribe(name) 186 | } 187 | } 188 | } 189 | 190 | func (c *subscriptionManagerCore) IsSubscribedTo(name string) bool { 191 | c.lock.Lock() 192 | defer c.lock.Unlock() 193 | 194 | _, nameOk := c.subscriptions[name] 195 | _, wildcardOk := c.subscriptions[ads.WildcardSubscription] 196 | return nameOk || wildcardOk 197 | } 198 | 199 | func (c *subscriptionManagerCore) UnsubscribeAll() { 200 | c.lock.Lock() 201 | defer c.lock.Unlock() 202 | 203 | for name := range c.subscriptions { 204 | c.unsubscribe(name) 205 | } 206 | } 207 | 208 | func (c *subscriptionManagerCore) subscribe(name string) { 209 | c.unsubscribe(name) 210 | c.subscriptions[name] = c.locator.Subscribe(c.ctx, c.typeURL, name, c.handler) 211 | } 212 | 213 | func (c *subscriptionManagerCore) unsubscribe(name string) { 214 | if unsub, ok := c.subscriptions[name]; ok { 215 | unsub() 216 | delete(c.subscriptions, name) 217 | } 218 | } 219 | 220 | // cleanSubscriptionsAndEstimateSize clones the given slice and removes duplicate elements by sorting 221 | // it. This ensures that the server does not process the same subscription twice for the same 222 | // request. 
It then estimates the size of send buffer to pass to 223 | // [BatchSubscriptionHandler.StartNotificationBatch] if a [SendBufferSizeEstimator] was provided. 224 | func (c *subscriptionManagerCore) cleanSubscriptionsAndEstimateSize( 225 | resourceNamesSubscribe []string, 226 | initialResourceVersions map[string]string, 227 | ) (cleaned []string, size int) { 228 | cleaned = slices.Clone(resourceNamesSubscribe) 229 | slices.Sort(cleaned) 230 | cleaned = slices.Compact(cleaned) 231 | if len(initialResourceVersions) == 0 && c.sizeEstimator != nil { 232 | size = c.sizeEstimator.EstimateSubscriptionSize(c.ctx, c.typeURL, cleaned) 233 | } 234 | return cleaned, size 235 | } 236 | -------------------------------------------------------------------------------- /internal/utils/set.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "iter" 6 | "maps" 7 | "slices" 8 | ) 9 | 10 | type Set[T comparable] map[T]struct{} 11 | 12 | func NewSet[T comparable](elements ...T) Set[T] { 13 | s := make(Set[T], len(elements)) 14 | for _, t := range elements { 15 | s.Add(t) 16 | } 17 | return s 18 | } 19 | 20 | func (s Set[T]) Add(t T) bool { 21 | _, ok := s[t] 22 | if ok { 23 | return false 24 | } 25 | s[t] = struct{}{} 26 | return true 27 | } 28 | 29 | func (s Set[T]) Contains(t T) bool { 30 | _, ok := s[t] 31 | return ok 32 | } 33 | 34 | func (s Set[T]) Remove(t T) bool { 35 | _, ok := s[t] 36 | if !ok { 37 | return false 38 | } 39 | delete(s, t) 40 | return true 41 | } 42 | 43 | func (s Set[T]) String() string { 44 | return fmt.Sprint(slices.Collect(maps.Keys(s))) 45 | } 46 | 47 | func (s Set[T]) Values() iter.Seq[T] { 48 | return maps.Keys(s) 49 | } 50 | -------------------------------------------------------------------------------- /internal/utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "encoding/base64" 5 | 
"encoding/binary" 6 | "encoding/hex" 7 | "slices" 8 | "strings" 9 | "time" 10 | 11 | types "github.com/envoyproxy/go-control-plane/pkg/resource/v3" 12 | "google.golang.org/protobuf/encoding/protowire" 13 | "google.golang.org/protobuf/proto" 14 | ) 15 | 16 | const ( 17 | // NonceLength is the length of the string returned by NewNonce. NewNonce encodes the current UNIX 18 | // time in nanos and the remaining chunks, encoded as 64-bit and 32-bit integers respectively, then 19 | // hex encoded. This means a nonce will always be 8 + 4 bytes, multiplied by 2 by the hex encoding. 20 | NonceLength = (8 + 4) * 2 21 | ) 22 | 23 | // NewNonce creates a new unique nonce based on the current UNIX time in nanos, always returning a 24 | // string of [NonceLength]. 25 | func NewNonce(remainingChunks int) string { 26 | return newNonce(time.Now(), remainingChunks) 27 | } 28 | 29 | func newNonce(now time.Time, remainingChunks int) string { 30 | // preallocating these buffers with constants (instead of doing `out = make([]byte, len(buf) * 2)`) 31 | // means the compiler will allocate them on the stack, instead of heap. This significantly reduces 32 | // the amount of garbage created by this function, as the only heap allocation will be the final 33 | // string(out), rather than all of these buffers. 34 | buf := make([]byte, NonceLength/2) 35 | out := make([]byte, NonceLength) 36 | 37 | binary.BigEndian.PutUint64(buf[:8], uint64(now.UnixNano())) 38 | binary.BigEndian.PutUint32(buf[8:], uint32(remainingChunks)) 39 | 40 | hex.Encode(out, buf) 41 | 42 | return string(out) 43 | } 44 | 45 | func GetTypeURL[T proto.Message]() string { 46 | var t T 47 | return getTypeURL(t) 48 | } 49 | 50 | func getTypeURL(t proto.Message) string { 51 | return types.APITypePrefix + string(t.ProtoReflect().Descriptor().FullName()) 52 | } 53 | 54 | // MapToProto serializes the given map using protobuf. It sorts the entries based on the key such that the same map 55 | // always produces the same output. 
It then encodes the entries by appending the key then value, and b64 encodes the 56 | // entire output. Note that the final b64 encoding step is critical as this function is intended to be used with 57 | // [ads.SotWDiscoveryResponse.Version], which is a string field. In protobuf, string fields must contain valid UTF-8 58 | // characters, and b64 encoding ensures that. 59 | func MapToProto(m map[string]string) string { 60 | if len(m) == 0 { 61 | return "" 62 | } 63 | 64 | var b []byte 65 | keys := make([]string, 0, len(m)) 66 | for k := range m { 67 | keys = append(keys, k) 68 | } 69 | slices.Sort(keys) 70 | 71 | for _, k := range keys { 72 | b = protowire.AppendString(b, k) 73 | b = protowire.AppendString(b, m[k]) 74 | } 75 | return base64.StdEncoding.EncodeToString(b) 76 | } 77 | 78 | // ProtoToMap is the inverse of MapToProto. It returns an error on any decoding or deserialization issues. 79 | func ProtoToMap(s string) (map[string]string, error) { 80 | if s == "" { 81 | return nil, nil 82 | } 83 | 84 | b, err := base64.StdEncoding.DecodeString(s) 85 | if err != nil { 86 | return nil, err 87 | } 88 | m := make(map[string]string) 89 | 90 | parse := func() (string, error) { 91 | s, n := protowire.ConsumeString(b) 92 | if n < 0 { 93 | return "", protowire.ParseError(n) 94 | } 95 | b = b[n:] 96 | return s, nil 97 | } 98 | 99 | for len(b) > 0 { 100 | k, err := parse() 101 | if err != nil { 102 | return nil, err 103 | } 104 | v, err := parse() 105 | if err != nil { 106 | return nil, err 107 | } 108 | m[k] = v 109 | } 110 | 111 | return m, nil 112 | } 113 | 114 | // IsPseudoDeltaSotW checks whether the given resource type url is intended to behave as a "pseudo 115 | // delta" resource. Instead of sending the entire state of the world for every resource change, the 116 | // server is expected to only send the changed resource. 
From [the spec]: 117 | // 118 | // In the SotW protocol variants, all resource types except for Listener and Cluster are grouped into 119 | // responses in the same way as in the incremental protocol variants. However, Listener and Cluster 120 | // resource types are handled differently: the server must include the complete state of the world, 121 | // meaning that all resources of the relevant type that are needed by the client must be included, 122 | // even if they did not change since the last response. 123 | // 124 | // In other words, for everything except Listener and Cluster, the server should only send the 125 | // changed resources, rather than every resource every time. 126 | // 127 | // [the spec]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol#grouping-resources-into-responses 128 | func IsPseudoDeltaSotW(typeURL string) bool { 129 | return !(typeURL == types.ListenerType || typeURL == types.ClusterType) 130 | } 131 | 132 | // TrimTypeURL removes the leading "types.googleapis.com/" prefix from the given string. 
133 | func TrimTypeURL(typeURL string) string { 134 | return strings.TrimPrefix(typeURL, types.APITypePrefix) 135 | } 136 | -------------------------------------------------------------------------------- /internal/utils/utils_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/envoyproxy/go-control-plane/pkg/resource/v3" 9 | "github.com/linkedin/diderot/ads" 10 | "github.com/stretchr/testify/require" 11 | "google.golang.org/protobuf/encoding/protowire" 12 | ) 13 | 14 | func TestGetTypeURL(t *testing.T) { 15 | require.Equal(t, resource.ListenerType, GetTypeURL[*ads.Listener]()) 16 | require.Equal(t, resource.EndpointType, GetTypeURL[*ads.Endpoint]()) 17 | require.Equal(t, resource.ClusterType, GetTypeURL[*ads.Cluster]()) 18 | require.Equal(t, resource.RouteType, GetTypeURL[*ads.Route]()) 19 | } 20 | 21 | func TestProtoMap(t *testing.T) { 22 | t.Run("good", func(t *testing.T) { 23 | m := map[string]string{ 24 | "foo": "bar", 25 | "baz": "qux", 26 | "empty": "", 27 | "": "empty", 28 | } 29 | s := MapToProto(m) 30 | m2, err := ProtoToMap(s) 31 | require.NoError(t, err) 32 | require.Equal(t, m, m2) 33 | 34 | // Check that on a different invocation, the output remains the same 35 | require.Equal(t, s, MapToProto(m)) 36 | 37 | m2, err = ProtoToMap("") 38 | require.NoError(t, err) 39 | require.Empty(t, m2) 40 | }) 41 | t.Run("bad", func(t *testing.T) { 42 | _, err := ProtoToMap("1") 43 | require.Error(t, err) 44 | 45 | b := protowire.AppendString(nil, "foo") 46 | _, err = ProtoToMap(string(b)) 47 | require.Error(t, err) 48 | }) 49 | } 50 | 51 | func TestNewNonce(t *testing.T) { 52 | now := time.Now() 53 | t.Run("remainingChunks", func(t *testing.T) { 54 | for _, expected := range []int{0, 42} { 55 | nonce := newNonce(now, expected) 56 | require.Equal(t, fmt.Sprintf("%x%08x", now.UnixNano(), expected), nonce) 57 | actualRemainingChunks, err := 
ads.ParseRemainingChunksFromNonce(nonce) 58 | require.NoError(t, err) 59 | require.Equal(t, expected, actualRemainingChunks) 60 | } 61 | }) 62 | t.Run("badNonce", func(t *testing.T) { 63 | remaining, err := ads.ParseRemainingChunksFromNonce("foo") 64 | require.Error(t, err) 65 | require.Zero(t, remaining) 66 | }) 67 | t.Run("oldNonce", func(t *testing.T) { 68 | remaining, err := ads.ParseRemainingChunksFromNonce(fmt.Sprintf("%x", now.UnixNano())) 69 | require.Error(t, err) 70 | require.Zero(t, remaining) 71 | }) 72 | } 73 | -------------------------------------------------------------------------------- /stats/server/server_stats.go: -------------------------------------------------------------------------------- 1 | package serverstats 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/linkedin/diderot/ads" 8 | "google.golang.org/protobuf/proto" 9 | ) 10 | 11 | // Handler will be invoked with an event of the corresponding type when said event occurs. 12 | type Handler interface { 13 | HandleServerEvent(context.Context, Event) 14 | } 15 | 16 | // Event contains information about a specific event that happened in the server. 17 | type Event interface { 18 | isServerEvent() 19 | } 20 | 21 | // RequestReceived contains the stats of a request received by the server. 22 | type RequestReceived struct { 23 | // The received request, either [ads.SotWDiscoveryRequest] or [ads.DeltaDiscoveryRequest]. 24 | Req proto.Message 25 | // Whether the request is an ACK 26 | IsACK bool 27 | // Whether the request is a NACK. Note that this is an important stat that requires immediate human 28 | // intervention. 29 | IsNACK bool 30 | // The given duration represents the time it took to handle the request, i.e. validating it and 31 | // processing its subscriptions if necessary. It does not include the time for any of the 32 | // resources to be sent in a response. 
	// Duration represents the time it took to handle the request, i.e. validating it and
	// processing its subscriptions if necessary. It does not include the time for any of the
	// resources to be sent in a response.
	Duration time.Duration
}

func (s *RequestReceived) isServerEvent() {}

// SentResource contains all the metadata about a resource sent by the server. Will be the 0-value
// for any resource that was provided via the initial_resource_versions field which was not
// explicitly subscribed to and did not exist.
type SentResource struct {
	// Resource is the resource itself, nil if the resource is being deleted.
	Resource *ads.RawResource
	// Metadata is the metadata for the resource and subscription.
	Metadata ads.SubscriptionMetadata
	// QueuedAt is the time at which the resource was queued to be sent. This means that it does
	// not include any time spent in the granular or global rate limiters, or sending the response,
	// which can take an arbitrarily long time due to flow control.
	QueuedAt time.Time
}

// ResponseSent contains the stats of a response sent by the server.
type ResponseSent struct {
	// TypeURL is the type URL of the resources in the response.
	TypeURL string
	// Resources holds the resources that were sent, keyed by resource name.
	Resources map[string]SentResource
	// Duration is how long the Send operation took. This includes any time added by flow-control.
	Duration time.Duration
}

func (s *ResponseSent) isServerEvent() {}

// TimeInGlobalRateLimiter contains the stats of the time spent in the global rate limiter.
type TimeInGlobalRateLimiter struct {
	// Duration is how long the server waited for the global rate limiter to clear.
	Duration time.Duration
}

func (s *TimeInGlobalRateLimiter) isServerEvent() {}

// ResourceMarshalError contains the stats for a resource that could not be marshaled. This
// should be extremely rare and requires immediate attention.
type ResourceMarshalError struct {
	// ResourceName is the name of the resource that could not be marshaled.
	ResourceName string
	// Resource is the resource that could not be marshaled.
	Resource proto.Message
	// Err is the marshaling error.
	Err error
}

func (s *ResourceMarshalError) isServerEvent() {}

// ResourceOverMaxSize contains the stats for a critical error that signals a resource will
// never be received by clients that are subscribed to it. It likely requires immediate human
// intervention.
type ResourceOverMaxSize struct {
	// Resource is the resource that could not be sent.
	Resource *ads.RawResource
	// ResourceSize is the encoded resource size.
	ResourceSize int
	// MaxResourceSize is the maximum resource size (usually 4MB, gRPC's default max message size).
	MaxResourceSize int
}

func (s *ResourceOverMaxSize) isServerEvent() {}

// UnknownResourceRequested indicates whether a resource that was subscribed never existed. This
// should be rare, and can be indicative of a bug (either the client is requesting an unknown
// resource because it is incorrectly configured, or the server is missing some resource that it is
// expected to have).
type UnknownResourceRequested struct {
	// TypeURL is the resource's type.
	TypeURL string
	// ResourceName is the resource's name.
	ResourceName string
}

func (s *UnknownResourceRequested) isServerEvent() {}

// IRVMatchedResource represents stats for resources that matches the `initial_resource_versions`
// provided by the client.
type IRVMatchedResource struct {
	// ResourceName is the name of the resource.
	ResourceName string
	// Resource is the resource itself, nil if the resource is being deleted.
118 | Resource *ads.RawResource 119 | } 120 | 121 | func (s *IRVMatchedResource) isServerEvent() {} 122 | -------------------------------------------------------------------------------- /test_xds_config.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "testADSServer", 4 | "version": "1", 5 | "resource": { 6 | "@type": "type.googleapis.com/envoy.config.listener.v3.Listener", 7 | "name": "testADSServer", 8 | "apiListener": { 9 | "apiListener": { 10 | "@type": "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager", 11 | "rds": { 12 | "configSource": { 13 | "ads": {}, 14 | "resourceApiVersion": "V3" 15 | }, 16 | "routeConfigName": "testADSServer" 17 | }, 18 | "httpFilters": [ 19 | { 20 | "name": "default", 21 | "typedConfig": { 22 | "@type": "type.googleapis.com/envoy.extensions.filters.http.router.v3.Router" 23 | } 24 | } 25 | ] 26 | } 27 | } 28 | } 29 | }, 30 | { 31 | "name": "testADSServer", 32 | "version": "1", 33 | "resource": { 34 | "@type": "type.googleapis.com/envoy.config.route.v3.RouteConfiguration", 35 | "name": "testADSServer", 36 | "virtualHosts": [ 37 | { 38 | "name": "testADSServer", 39 | "domains": [ 40 | "*" 41 | ], 42 | "routes": [ 43 | { 44 | "name": "default", 45 | "match": { 46 | "prefix": "" 47 | }, 48 | "route": { 49 | "cluster": "testADSServer" 50 | } 51 | } 52 | ] 53 | } 54 | ] 55 | } 56 | }, 57 | { 58 | "name": "testADSServer", 59 | "version": "1", 60 | "resource": { 61 | "@type": "type.googleapis.com/envoy.config.cluster.v3.Cluster", 62 | "name": "testADSServer", 63 | "type": "EDS", 64 | "edsClusterConfig": { 65 | "edsConfig": { 66 | "ads": {}, 67 | "resourceApiVersion": "V3" 68 | } 69 | } 70 | } 71 | } 72 | ] -------------------------------------------------------------------------------- /testutils/testutils.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 
4 | "context" 5 | "maps" 6 | "net" 7 | "slices" 8 | "testing" 9 | "time" 10 | 11 | "github.com/google/go-cmp/cmp" 12 | "github.com/linkedin/diderot/ads" 13 | "github.com/stretchr/testify/require" 14 | "google.golang.org/grpc" 15 | "google.golang.org/grpc/credentials/insecure" 16 | "google.golang.org/protobuf/encoding/prototext" 17 | "google.golang.org/protobuf/proto" 18 | ) 19 | 20 | func WithTimeout(t *testing.T, name string, timeout time.Duration, f func(t *testing.T)) { 21 | t.Run(name, func(t *testing.T) { 22 | t.Helper() 23 | done := make(chan struct{}) 24 | go func() { 25 | f(t) 26 | close(done) 27 | }() 28 | timer := time.NewTimer(timeout) 29 | defer timer.Stop() 30 | select { 31 | case <-timer.C: 32 | t.Fatalf("%q failed to complete in %s", t.Name(), timeout) 33 | case <-done: 34 | return 35 | } 36 | }) 37 | } 38 | 39 | func Context(tb testing.TB) context.Context { 40 | ctx, cancel := context.WithCancel(context.Background()) 41 | tb.Cleanup(cancel) 42 | return ctx 43 | } 44 | 45 | func ContextWithTimeout(tb testing.TB, timeout time.Duration) context.Context { 46 | ctx, cancel := context.WithTimeout(context.Background(), timeout) 47 | tb.Cleanup(cancel) 48 | return ctx 49 | } 50 | 51 | type Notification[T proto.Message] struct { 52 | Name string 53 | Resource *ads.Resource[T] 54 | Metadata ads.SubscriptionMetadata 55 | } 56 | 57 | type ChanSubscriptionHandler[T proto.Message] chan Notification[T] 58 | 59 | func (c ChanSubscriptionHandler[T]) Notify(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata) { 60 | c <- Notification[T]{ 61 | Name: name, 62 | Resource: r, 63 | Metadata: metadata, 64 | } 65 | } 66 | 67 | // This is the bare minimum required by the testify framework. *testing.T implements it, but this 68 | // interface is used for testing the test framework. 
// testingT is the bare minimum required by the testify framework. *testing.T implements it, but
// this interface is used for testing the test framework.
type testingT interface {
	Logf(format string, args ...any)
	Errorf(format string, args ...any)
	FailNow()
	Helper()
	Fatalf(string, ...any)
}

// Compile-time checks that both *testing.T and *testing.B satisfy testingT.
var _ testingT = (*testing.T)(nil)
var _ testingT = (*testing.B)(nil)

// ExpectedNotification describes a notification expected by WaitForNotifications: an update when
// Resource is non-nil, a deletion otherwise.
type ExpectedNotification[T proto.Message] struct {
	Name     string
	Resource *ads.Resource[T]
}

// ExpectDelete returns an ExpectedNotification for the deletion of the given resource name.
func ExpectDelete[T proto.Message](name string) ExpectedNotification[T] {
	return ExpectedNotification[T]{Name: name}
}

// ExpectUpdate returns an ExpectedNotification for an update to the given resource.
func ExpectUpdate[T proto.Message](r *ads.Resource[T]) ExpectedNotification[T] {
	return ExpectedNotification[T]{Name: r.Name, Resource: r}
}

// WaitForDelete waits for a single deletion notification for the given name.
func (c ChanSubscriptionHandler[T]) WaitForDelete(
	t testingT,
	expectedName string,
) Notification[T] {
	t.Helper()
	return c.WaitForNotifications(t, ExpectDelete[T](expectedName))[0]
}

// WaitForUpdate waits for a single update notification matching the given resource.
func (c ChanSubscriptionHandler[T]) WaitForUpdate(t testingT, r *ads.Resource[T]) Notification[T] {
	t.Helper()
	return c.WaitForNotifications(t, ExpectUpdate(r))[0]
}

// WaitForNotifications receives len(notifications) notifications from the channel, in any arrival
// order, and validates each one against the expectation with the matching name. The results are
// returned in the same order as the expectations. The test fails if a notification does not arrive
// within 5 seconds, if an unexpected name is received, or if an update arrives where a deletion was
// expected (or vice versa).
// NOTE(review): expectations are keyed by name, so duplicate names in notifications overwrite each
// other — callers appear to always pass unique names; confirm before relying on duplicates.
func (c ChanSubscriptionHandler[T]) WaitForNotifications(t testingT, notifications ...ExpectedNotification[T]) (out []Notification[T]) {
	t.Helper()

	// Maps each expected name to its index in notifications so that received notifications can be
	// matched regardless of arrival order.
	expectedNotifications := make(map[string]int)
	for i, n := range notifications {
		expectedNotifications[n.Name] = i
	}

	out = make([]Notification[T], len(notifications))

	for range notifications {
		var n Notification[T]
		select {
		case n = <-c:
		// The 5s timeout applies per notification, not to the batch as a whole.
		case <-time.After(5 * time.Second):
			t.Fatalf("Did not receive expected notification for one of: %v",
				slices.Collect(maps.Keys(expectedNotifications)))
		}

		idx, ok := expectedNotifications[n.Name]
		if !ok {
			require.Fail(t, "Received unexpected notification", n.Name)
		}
		expected := notifications[idx]
		out[idx] = n
		// Remove the matched name so that a second notification for it is reported as unexpected.
		delete(expectedNotifications, n.Name)

		if expected.Resource != nil {
			require.NotNilf(t, n.Resource, "Expected update for %q, got deletion instead", expected.Name)
			ResourceEquals(t, expected.Resource, n.Resource)
		} else {
			require.Nilf(t, n.Resource, "Expected delete for %q, got update instead", expected.Name)
		}
	}

	// All expectations must have been consumed by now.
	require.Empty(t, expectedNotifications)

	return out
}

// ResourceEquals asserts that the two resources are field-for-field equal, comparing proto-typed
// fields with proto.Equal rather than ==.
func ResourceEquals[T proto.Message](t testingT, expected, actual *ads.Resource[T]) {
	t.Helper()
	require.Equal(t, expected.Name, actual.Name)
	require.Equal(t, expected.Version, actual.Version)
	ProtoEquals(t, expected.Resource, actual.Resource)
	ProtoEquals(t, expected.Ttl, actual.Ttl)
	ProtoEquals(t, expected.CacheControl, actual.CacheControl)
	ProtoEquals(t, expected.Metadata, actual.Metadata)
}

// ProtoEquals asserts that the two messages are equal per proto.Equal, failing with a text-format
// diff on mismatch.
func ProtoEquals(t testingT, expected, actual proto.Message) {
	t.Helper()
	if !proto.Equal(expected, actual) {
		t.Fatalf(
			"Messages not equal:\nexpected:%s\nactual :%s\n%s",
			expected, actual,
			cmp.Diff(prototext.Format(expected), prototext.Format(actual)),
		)
	}
}

// FuncHandler is a SubscriptionHandler implementation that simply invokes a function. Note that the usual pattern of
// having a literal func type implement the interface (e.g. http.HandlerFunc) does not work in this case because funcs
// are not hashable and therefore cannot be used as map keys, which is often how SubscriptionHandlers are used.
type FuncHandler[T proto.Message] struct {
	// notify is invoked for every notification delivered to this handler.
	notify func(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata)
}

// Notify implements SubscriptionHandler by delegating to the wrapped function.
func (f *FuncHandler[T]) Notify(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata) {
	f.notify(name, r, metadata)
}

// NewSubscriptionHandler returns a SubscriptionHandler that invokes the given function when
// SubscriptionHandler.Notify is invoked.
func NewSubscriptionHandler[T proto.Message](
	notify func(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata),
) *FuncHandler[T] {
	return &FuncHandler[T]{
		notify: notify,
	}
}

// RawFuncHandler is the untyped (raw) counterpart of FuncHandler: it invokes a function for every
// notification, and fails the test on resource marshal errors.
type RawFuncHandler struct {
	t      testingT
	notify func(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata)
}

// Notify implements RawSubscriptionHandler by delegating to the wrapped function.
func (r *RawFuncHandler) Notify(name string, raw *ads.RawResource, metadata ads.SubscriptionMetadata) {
	r.notify(name, raw, metadata)
}

// ResourceMarshalError fails the test immediately: marshal errors are never expected in tests.
func (r *RawFuncHandler) ResourceMarshalError(name string, resource proto.Message, err error) {
	r.t.Fatalf("Unexpected resource marshal error for %q: %v\n%v", name, err, resource)
}

// NewRawSubscriptionHandler returns a RawSubscriptionHandler that invokes the given function when
// SubscriptionHandler.Notify is invoked.
func NewRawSubscriptionHandler(
	t testingT,
	notify func(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata),
) *RawFuncHandler {
	return &RawFuncHandler{t: t, notify: notify}
}

// TestServer is instantiated with NewTestGRPCServer and serves to facilitate local testing against
// gRPC service implementations.
type TestServer struct {
	t *testing.T
	*grpc.Server
	net.Listener
}

// Start starts the backing gRPC server in a goroutine. Must be invoked _after_ registering the services.
func (ts *TestServer) Start() {
	go func() {
		// Serve blocks until the server stops; any serve error fails the test.
		require.NoError(ts.t, ts.Server.Serve(ts.Listener))
	}()
}

// Dial returns a client connection to this server, defaulting to insecure transport credentials
// (the given options are appended and may extend this). The connection is closed automatically when
// the test ends.
// NOTE(review): previous comment said this "invokes DialContext", but the implementation uses
// grpc.NewClient, which creates the connection lazily.
func (ts *TestServer) Dial(opts ...grpc.DialOption) *grpc.ClientConn {
	opts = append([]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, opts...)
	conn, err := grpc.NewClient(ts.AddrString(), opts...)
	require.NoError(ts.t, err)
	ts.t.Cleanup(func() {
		require.NoError(ts.t, conn.Close())
	})
	return conn
}

// AddrString returns the listener's address as a "host:port" string.
func (ts *TestServer) AddrString() string {
	return ts.Addr().String()
}

// NewTestGRPCServer is a utility function that spins up a TCP listener on a random local port along
// with a grpc.Server. It cleans up any associated state using the Cleanup methods. Sample usage is
// as follows:
//
//	ts := NewTestGRPCServer(t)
//	discovery.RegisterAggregatedDiscoveryServiceServer(ts.Server, s)
//	ts.Start()
//	conn := ts.Dial()
func NewTestGRPCServer(t *testing.T, opts ...grpc.ServerOption) *TestServer {
	ts := &TestServer{
		t:      t,
		Server: grpc.NewServer(opts...),
	}

	var err error
	ts.Listener, err = net.Listen("tcp", "localhost:0")
	require.NoError(t, err)

	t.Cleanup(func() {
		ts.Server.Stop()
	})

	return ts
}

// MustMarshal marshals the given resource, failing the test on error.
func MustMarshal[T proto.Message](t testingT, r *ads.Resource[T]) *ads.RawResource {
	marshaled, err := r.Marshal()
	require.NoError(t, err)
	return marshaled
}
"github.com/stretchr/testify/require" 9 | "google.golang.org/protobuf/types/known/wrapperspb" 10 | ) 11 | 12 | var failNowInvoked = new(byte) 13 | 14 | type testingTMock testing.T 15 | 16 | func (t *testingTMock) Errorf(format string, args ...any) { 17 | (*testing.T)(t).Logf(format, args...) 18 | } 19 | 20 | func (t *testingTMock) Fatalf(format string, args ...any) { 21 | (*testing.T)(t).Logf(format, args...) 22 | t.FailNow() 23 | } 24 | 25 | func (t *testingTMock) FailNow() { 26 | panic(failNowInvoked) 27 | } 28 | 29 | func (t *testingTMock) Helper() { 30 | } 31 | 32 | func TestChanSubscriptionHandler_WaitForNotification(t *testing.T) { 33 | const foo = "foo" 34 | 35 | expected := wrapperspb.Int64(42) 36 | version := "0" 37 | resource := &ads.Resource[*wrapperspb.Int64Value]{ 38 | Name: foo, 39 | Version: version, 40 | Resource: expected, 41 | } 42 | 43 | tests := []struct { 44 | name string 45 | shouldFail bool 46 | test func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) 47 | }{ 48 | { 49 | name: "receive different name", 50 | shouldFail: true, 51 | test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { 52 | metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} 53 | h.Notify("bar", nil, metadata) 54 | h.WaitForDelete(mock, foo) 55 | }, 56 | }, 57 | { 58 | name: "expect delete", 59 | shouldFail: false, 60 | test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { 61 | metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} 62 | h.Notify(foo, nil, metadata) 63 | h.WaitForDelete(mock, foo) 64 | }, 65 | }, 66 | { 67 | name: "expect delete, get update", 68 | shouldFail: true, 69 | test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { 70 | metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} 71 | h.Notify(foo, resource, metadata) 72 | h.WaitForDelete(mock, foo) 73 | }, 74 | }, 75 | { 76 | name: "expect update", 77 | shouldFail: false, 
78 | test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { 79 | metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} 80 | h.Notify(foo, resource, metadata) 81 | h.WaitForUpdate(mock, resource) 82 | }, 83 | }, 84 | { 85 | name: "expect update, get delete", 86 | shouldFail: true, 87 | test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { 88 | metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} 89 | h.Notify(foo, nil, metadata) 90 | h.WaitForUpdate(mock, resource) 91 | }, 92 | }, 93 | { 94 | name: "received different value", 95 | shouldFail: true, 96 | test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { 97 | metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} 98 | h.Notify(foo, resource, metadata) 99 | h.WaitForUpdate(mock, ads.NewResource[*wrapperspb.Int64Value](foo, version, wrapperspb.Int64(27))) 100 | }, 101 | }, 102 | } 103 | 104 | for _, test := range tests { 105 | t.Run(test.name, func(t *testing.T) { 106 | h := make(ChanSubscriptionHandler[*wrapperspb.Int64Value], 1) 107 | mock := (*testingTMock)(t) 108 | if test.shouldFail { 109 | require.PanicsWithValuef(t, failNowInvoked, func() { 110 | test.test(mock, h) 111 | }, "did not panic!") 112 | } else { 113 | test.test(mock, h) 114 | } 115 | }) 116 | } 117 | 118 | } 119 | -------------------------------------------------------------------------------- /type.go: -------------------------------------------------------------------------------- 1 | package diderot 2 | 3 | import ( 4 | "github.com/linkedin/diderot/ads" 5 | "github.com/linkedin/diderot/internal/utils" 6 | "google.golang.org/protobuf/proto" 7 | ) 8 | 9 | // typeReference is the only implementation of the Type and, by extension, the TypeReference 10 | // interface. It is not exposed publicly to ensure that all instances are generated through TypeOf, 11 | // which uses reflection on the type parameter to determine the type URL. 
// This is to avoid potential runtime complications due to invalid type URL strings.
type typeReference[T proto.Message] string

// TypeReference is a superset of the Type interface which captures the actual runtime type.
type TypeReference[T proto.Message] interface {
	Type
}

// Type is a type reference for a type that can be cached. Only accessible through TypeOf.
type Type interface {
	// URL returns the type URL for this Type.
	URL() string
	// TrimmedURL returns the type URL for this Type without the leading "type.googleapis.com/"
	// prefix. This string is useful when constructing xdstp URLs.
	TrimmedURL() string
	// NewCache is the untyped equivalent of this package's NewCache. The returned RawCache still
	// retains the runtime type information and can be safely cast to the corresponding Cache type.
	NewCache() RawCache
	// NewPrioritizedCache is the untyped equivalent of this package's NewPrioritizedCache. The returned
	// RawCache instances can be safely cast to the corresponding Cache type.
	NewPrioritizedCache(prioritySlots int) []RawCache

	// The unexported methods below let package-level Subscribe/Unsubscribe/IsSubscribedTo recover
	// the typed Cache API from an untyped RawCache.
	isSubscribedTo(c RawCache, name string, handler ads.RawSubscriptionHandler) bool
	subscribe(c RawCache, name string, handler ads.RawSubscriptionHandler)
	unsubscribe(c RawCache, name string, handler ads.RawSubscriptionHandler)
}

// URL returns the raw type URL string.
func (t typeReference[T]) URL() string {
	return string(t)
}

// TrimmedURL returns the type URL without the "type.googleapis.com/" prefix.
func (t typeReference[T]) TrimmedURL() string {
	return utils.TrimTypeURL(t.URL())
}

// NewCache returns a new typed cache for T, as an untyped RawCache.
func (t typeReference[T]) NewCache() RawCache {
	return NewCache[T]()
}

// NewPrioritizedCache returns the prioritized caches for T, as untyped RawCache values.
func (t typeReference[T]) NewPrioritizedCache(prioritySlots int) []RawCache {
	caches := NewPrioritizedCache[T](prioritySlots)
	out := make([]RawCache, len(caches))
	for i, c := range caches {
		out[i] = c
	}
	return out
}

// wrappedHandler adapts a RawSubscriptionHandler to the typed SubscriptionHandler interface by
// marshaling resources before delegating to the embedded raw handler.
type wrappedHandler[T proto.Message] struct {
	ads.RawSubscriptionHandler
}

// Notify marshals the typed resource (nil stays nil, signaling deletion) and forwards it to the
// raw handler. A marshal failure is reported via ResourceMarshalError and the notification is
// dropped.
func (w wrappedHandler[T]) Notify(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata) {
	var raw *ads.RawResource
	if r != nil {
		var err error
		raw, err = r.Marshal()
		if err != nil {
			w.RawSubscriptionHandler.ResourceMarshalError(name, r.Resource, err)
			return
		}
	}
	w.RawSubscriptionHandler.Notify(name, raw, metadata)
}

// toGenericHandler wraps the given RawSubscriptionHandler into a typed SubscriptionHandler.
// Multiple invocations of this function with the same RawSubscriptionHandler always return a semantically
// equivalent value, meaning it's possible to do the following, without needing to explicitly store
// and reuse the returned SubscriptionHandler:
//
//	var c Cache[*ads.Endpoint]
//	var rawHandler RawSubscriptionHandler
//	c.Subscribe("foo", ToGenericHandler[*ads.Endpoint](rawHandler))
//	c.Unsubscribe("foo", ToGenericHandler[*ads.Endpoint](rawHandler))
func (t typeReference[T]) toGenericHandler(raw ads.RawSubscriptionHandler) ads.SubscriptionHandler[T] {
	// wrappedHandler only embeds the raw handler, so two wrappers around the same handler compare
	// equal; this is what makes Subscribe/Unsubscribe symmetric without storing the wrapper.
	return wrappedHandler[T]{raw}
}

// isSubscribedTo casts the RawCache back to its typed Cache and checks the subscription.
func (t typeReference[T]) isSubscribedTo(c RawCache, name string, handler ads.RawSubscriptionHandler) bool {
	return c.(Cache[T]).IsSubscribedTo(name, t.toGenericHandler(handler))
}

// subscribe casts the RawCache back to its typed Cache and subscribes the wrapped handler.
func (t typeReference[T]) subscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) {
	c.(Cache[T]).Subscribe(name, t.toGenericHandler(handler))
}

// unsubscribe casts the RawCache back to its typed Cache and unsubscribes the wrapped handler.
func (t typeReference[T]) unsubscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) {
	c.(Cache[T]).Unsubscribe(name, t.toGenericHandler(handler))
}

// TypeOf returns a TypeReference that corresponds to the type parameter.
func TypeOf[T proto.Message]() TypeReference[T] {
	return typeReference[T](utils.GetTypeURL[T]())
}

// IsSubscribedTo checks whether the given handler is subscribed to the given named resource by invoking
// the underlying generic API [diderot.Cache.IsSubscribedTo].
func IsSubscribedTo(c RawCache, name string, handler ads.RawSubscriptionHandler) bool {
	return c.Type().isSubscribedTo(c, name, handler)
}

// Subscribe registers the handler as a subscriber of the given named resource by invoking the
// underlying generic API [diderot.Cache.Subscribe].
115 | func Subscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) { 116 | c.Type().subscribe(c, name, handler) 117 | } 118 | 119 | // Unsubscribe unregisters the handler as a subscriber of the given named resource by invoking the 120 | // underlying generic API [diderot.Cache.Unsubscribe]. 121 | func Unsubscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) { 122 | c.Type().unsubscribe(c, name, handler) 123 | } 124 | -------------------------------------------------------------------------------- /type_test.go: -------------------------------------------------------------------------------- 1 | package diderot 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/linkedin/diderot/ads" 8 | "github.com/linkedin/diderot/testutils" 9 | "github.com/stretchr/testify/require" 10 | "google.golang.org/protobuf/types/known/wrapperspb" 11 | ) 12 | 13 | func TestType(t *testing.T) { 14 | tests := []struct { 15 | Name string 16 | UseRawSetter bool 17 | }{ 18 | { 19 | Name: "typed", 20 | UseRawSetter: false, 21 | }, 22 | { 23 | Name: "raw", 24 | UseRawSetter: true, 25 | }, 26 | } 27 | 28 | for _, test := range tests { 29 | t.Run(test.Name, func(t *testing.T) { 30 | c := NewCache[*wrapperspb.BoolValue]() 31 | 32 | const foo = "foo" 33 | 34 | r := &ads.Resource[*wrapperspb.BoolValue]{ 35 | Name: foo, 36 | Version: "0", 37 | Resource: wrapperspb.Bool(true), 38 | } 39 | if test.UseRawSetter { 40 | require.NoError(t, c.SetRaw(testutils.MustMarshal(t, r), time.Time{})) 41 | } else { 42 | c.SetResource(r, time.Time{}) 43 | } 44 | 45 | testutils.ResourceEquals(t, r, c.Get(foo)) 46 | 47 | c.Clear(foo, time.Time{}) 48 | require.Nil(t, c.Get(foo)) 49 | }) 50 | } 51 | } 52 | --------------------------------------------------------------------------------