From 3a4c65c1c198b28ff4d4cef27c4c31db12b694b0 Mon Sep 17 00:00:00 2001 From: PapaCharlie Date: Tue, 9 Jul 2024 20:09:33 -0400 Subject: [PATCH] Initial commit --- .github/workflows/build.yml | 26 + .github/workflows/golangci-lint.yml | 26 + .gitignore | 13 + .golangci.yaml | 13 + CONTRIBUTING | 1 + LICENSE | 28 + Makefile | 87 ++ NOTICE | 4 + README.md | 29 + ads/ads.go | 257 ++++++ ads/glob_collection_url.go | 162 ++++ ads/glob_collection_url_test.go | 158 ++++ cache.go | 377 +++++++++ cache_test.go | 794 ++++++++++++++++++ doc.go | 79 ++ examples/quickstart/main.go | 117 +++ go.mod | 36 + go.sum | 58 ++ internal/cache/glob_collection.go | 119 +++ internal/cache/glob_collections_map.go | 111 +++ internal/cache/resource_map.go | 153 ++++ internal/cache/subscriber_set.go | 130 +++ internal/cache/subscriber_set_test.go | 97 +++ internal/cache/subscription_type.go | 23 + internal/cache/subscription_type_test.go | 13 + internal/cache/watchable_value.go | 412 ++++++++++ internal/server/handlers.go | 364 +++++++++ internal/server/handlers_bench_test.go | 69 ++ internal/server/handlers_delta.go | 204 +++++ internal/server/handlers_delta_test.go | 178 ++++ internal/server/handlers_test.go | 231 ++++++ internal/server/limiter.go | 33 + internal/server/limiter_test.go | 86 ++ internal/server/subscription_manager.go | 216 +++++ internal/utils/set.go | 44 + internal/utils/utils.go | 125 +++ internal/utils/utils_test.go | 51 ++ server.go | 479 +++++++++++ server_test.go | 984 +++++++++++++++++++++++ stats/server/server_stats.go | 99 +++ test_xds_config.json | 72 ++ testutils/testutils.go | 266 ++++++ testutils/testutils_test.go | 118 +++ type.go | 123 +++ type_test.go | 51 ++ 45 files changed, 7116 insertions(+) create mode 100644 .github/workflows/build.yml create mode 100644 .github/workflows/golangci-lint.yml create mode 100644 .gitignore create mode 100644 .golangci.yaml create mode 100644 CONTRIBUTING create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 NOTICE create mode 100644 README.md create mode 100644 ads/ads.go create mode 100644 ads/glob_collection_url.go create mode 100644 ads/glob_collection_url_test.go create mode 100644 cache.go create mode 100644 cache_test.go create mode 100644 doc.go create mode 100644 examples/quickstart/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/cache/glob_collection.go create mode 100644 internal/cache/glob_collections_map.go create mode 100644 internal/cache/resource_map.go create mode 100644 internal/cache/subscriber_set.go create mode 100644 internal/cache/subscriber_set_test.go create mode 100644 internal/cache/subscription_type.go create mode 100644 internal/cache/subscription_type_test.go create mode 100644 internal/cache/watchable_value.go create mode 100644 internal/server/handlers.go create mode 100644 internal/server/handlers_bench_test.go create mode 100644 internal/server/handlers_delta.go create mode 100644 internal/server/handlers_delta_test.go create mode 100644 internal/server/handlers_test.go create mode 100644 internal/server/limiter.go create mode 100644 internal/server/limiter_test.go create mode 100644 internal/server/subscription_manager.go create mode 100644 internal/utils/set.go create mode 100644 internal/utils/utils.go create mode 100644 internal/utils/utils_test.go create mode 100644 server.go create mode 100644 server_test.go create mode 100644 stats/server/server_stats.go create mode 100644 test_xds_config.json create mode 100644 testutils/testutils.go create mode 100644 
testutils/testutils_test.go create mode 100644 type.go create mode 100644 type_test.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..249174c --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,26 @@ +name: build + +on: + push: + branches: + - master + pull_request: + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.23.x' + - name: Install + run: go get -v . + - name: Build + run: make build + - name: Test + run: make test TESTVERBOSE=-v diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml new file mode 100644 index 0000000..5b95a86 --- /dev/null +++ b/.github/workflows/golangci-lint.yml @@ -0,0 +1,26 @@ +name: golangci-lint + +on: + push: + branches: + - master + pull_request: + +permissions: + contents: read + # Optional: allow read access to pull request. Use with `only-new-issues` option. + # pull-requests: read + +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: stable + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.60 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f7d5008 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +.DS_store +.*.swp +.*.swo +*.iml +*.ipr +*.iws +*.sublime-* +.direnv/ +.gradle/ +.idea/ +.vscode/ +*.prof +.cov diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..5582794 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,13 @@ +linters: + enable: + - bodyclose + - errname + - errorlint + - exhaustive + - goconst + - gofmt + - goimports + - gocritic + - predeclared + - usestdlibvars + - unused diff --git a/CONTRIBUTING b/CONTRIBUTING new file mode 100644 index 0000000..4f1182a --- /dev/null +++ b/CONTRIBUTING @@ -0,0 +1 @@ +As a contributor, you represent that the code you submit is your original work or that of your employer (in which case you represent you have the right to bind your employer). By submitting code, you (and, if applicable, your employer) are licensing the submitted code to LinkedIn and the open source community subject to the BSD 2-Clause license. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..cf8d69e --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 2-CLAUSE LICENSE + +Copyright 2024 LinkedIn Corporation +All Rights Reserved. + +Redistribution and use in source and binary forms, with or +without modification, are permitted provided that the following +conditions are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following +disclaimer in the documentation and/or other materials provided +with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..fa8b6eb --- /dev/null +++ b/Makefile @@ -0,0 +1,87 @@ +PACKAGE = github.com/linkedin/diderot +SOURCE_FILES = $(wildcard $(shell git ls-files)) +PROFILES = out +COVERAGE = $(PROFILES)/diderot.cov +GOBIN = $(shell go env GOPATH)/bin + +# "all" is invoked on a bare "make" call since it's the first recipe. It just formats the code and +# checks that all packages can be compiled +.PHONY: all +all: fmt build + +build: + go build -v ./... + go test -v -c -o /dev/null $$(go list -f '{{if .TestGoFiles}}{{.ImportPath}}{{end}}' ./...) + +tidy: + go mod tidy + +vet: + go vet ./... + +$(GOBIN)/goimports: + go install golang.org/x/tools/cmd/goimports@latest + +.PHONY: fmt +fmt: $(GOBIN)/goimports + $(GOBIN)/goimports -w . + +# Can be used to change the number of tests run, defaults to 1 to prevent caching
TESTCOUNT = 1 +# Can be used to change the verbosity of tests: make test TESTVERBOSE=-v +TESTVERBOSE = +# Can be used to generate coverage reports for a specific package +COVERPKG = $(PACKAGE) +# Can be used to change which package gets tested, defaults to all packages. +TESTPKG = ./... + +test: $(COVERAGE) +$(COVERAGE): + @mkdir -p $(@D) + go test -race -coverprofile=$(COVERAGE) -coverpkg=$(COVERPKG)/... -count=$(TESTCOUNT) $(TESTVERBOSE) $(TESTPKG) + +coverage: $(COVERAGE) + go tool cover -html=$(COVERAGE) + +profile_cache: + $(MAKE) -B $(PROFILES)/BenchmarkCacheThroughput.bench BENCH_PKG=./cache + +profile_handlers: + $(MAKE) -B $(PROFILES)/BenchmarkHandlers.bench BENCH_PKG=./server + +BENCHCOUNT = 1 +BENCHTIME = 1s + +$(PROFILES)/%.bench: +ifdef BENCH_PKG + $(eval BENCHBIN=$(PROFILES)/$*) + mkdir -p $(PROFILES) + go test -c \ + -o $(BENCHBIN) \ + ./$(BENCH_PKG) + cd $(BENCH_PKG) && $(BENCHBIN) \ + -test.count $(BENCHCOUNT) \ + -test.benchmem \ + -test.bench="^$*$$" \ + -test.cpuprofile $(PROFILES)/$*.cpu \ + -test.memprofile $(PROFILES)/$*.mem \ + -test.blockprofile $(PROFILES)/$*.block \ + -test.benchtime $(BENCHTIME) \ + -test.run "^$$" $(BENCHVERBOSE) \ + . | tee $(abspath $@) $(abspath $(BENCHOUT)) +else + $(error BENCH_PKG undefined) +endif +ifdef OPEN_PROFILES + go tool pprof $(BENCHBIN) $(PROFILES)/$*.cpu <<< web + go tool pprof $(PROFILES)/$*.mem <<< web +else + $(info Not opening profiles since OPEN_PROFILES is not set) +endif + +$(GOBIN)/pkgsite: + go install golang.org/x/pkgsite/cmd/pkgsite@latest + +docs: $(GOBIN)/pkgsite + $(GOBIN)/pkgsite -open . + diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000..802e113 --- /dev/null +++ b/NOTICE @@ -0,0 +1,4 @@ +Copyright 2024 LinkedIn Corporation +All Rights Reserved. + +Licensed under the BSD 2-Clause License (the "License"). See License in the project root for license information.
diff --git a/README.md b/README.md new file mode 100644 index 0000000..3dad827 --- /dev/null +++ b/README.md @@ -0,0 +1,29 @@ +# Diderot +(pronounced dee-duh-row) + +--- + +Diderot is a server implementation of +the [xDS protocol](https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol) that makes it extremely easy and +efficient to implement a control plane for your Envoy and gRPC services. For the most up-to-date information, please +visit the [documentation](https://pkg.go.dev/github.com/linkedin/diderot). + +## Quick Start Guide +The only thing you need to implement to make your resources available via xDS is a +`diderot.ResourceLocator` ([link](https://pkg.go.dev/github.com/linkedin/diderot#ResourceLocator)). It is the interface +exposed by the [ADS server implementation](https://pkg.go.dev/github.com/linkedin/diderot#ADSServer), which should +contain the business logic of all your resource definitions and how to find them. To facilitate this implementation, +Diderot provides an efficient, low-resource [cache](https://pkg.go.dev/github.com/linkedin/diderot#Cache) that supports +highly concurrent updates. By leveraging the cache implementation for the heavy lifting, you will be able to focus on +the meaningful part of operating your own xDS control plane: your resource definitions. + +Once you have implemented your `ResourceLocator`, you can simply drop a `diderot.ADSServer` into your gRPC service, and +you're ready to go! Please refer to the [examples/quickstart](examples/quickstart/main.go) package for a complete, +runnable example; a brief sketch also appears at the end of this README. + +## Features +Diderot's ADS server is a faithful implementation of the xDS protocol, supporting both the State-of-the-World and +Delta/Incremental variants. It also supports advanced features such as +[glob collections](https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#glob), unlocking the more +efficient alternative to the `EDS` stage: `LEDS` +([design doc](https://docs.google.com/document/d/1aZ9ddX99BOWxmfiWZevSB5kzLAfH2TS8qQDcCBHcfSE/edit#heading=h.mmb97owcrx3c)). +
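+## Example
+The sketch below gives a rough idea of the moving parts. It is illustrative only: the `ResourceLocator`
+implementation and the `ADSServer` construction are elided, and the resource name is made up; see
+[examples/quickstart](examples/quickstart/main.go) for the authoritative version.
+
+```go
+package main
+
+import (
+	"net"
+	"time"
+
+	discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
+	"github.com/linkedin/diderot"
+	"github.com/linkedin/diderot/ads"
+	"google.golang.org/grpc"
+)
+
+func main() {
+	// A Cache holds the resources served to clients. A ResourceLocator (not
+	// shown here) would route client subscriptions to caches like this one.
+	cache := diderot.NewCache[*ads.Endpoint]()
+	cache.Set("my-service", "v1", new(ads.Endpoint), time.Time{})
+
+	// adsServer stands in for the *diderot.ADSServer built from your locator;
+	// it implements ads.Server, the ADS service interface.
+	var adsServer ads.Server
+	server := grpc.NewServer()
+	discovery.RegisterAggregatedDiscoveryServiceServer(server, adsServer)
+
+	lis, err := net.Listen("tcp", ":15010")
+	if err != nil {
+		panic(err)
+	}
+	_ = server.Serve(lis)
+}
+```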
diff --git a/ads/ads.go b/ads/ads.go new file mode 100644 index 0000000..97a9511 --- /dev/null +++ b/ads/ads.go @@ -0,0 +1,257 @@ +/* +Package ads provides a set of utilities and definitions around the Aggregated Discovery Service xDS +protocol (ADS), such as convenient type aliases, constants and core definitions. +*/ +package ads + +import ( + "log/slog" + "sync" + "time" + + cluster "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + listener "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + route "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + tls "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + runtime "github.com/envoyproxy/go-control-plane/envoy/service/runtime/v3" + types "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/durationpb" +) + +// Aliases to xDS types, for convenience and brevity. +type ( + // Server is the core interface that needs to be implemented by an xDS control plane. The + // "Aggregated" service (i.e. ADS, the name of this package) is type agnostic: the desired + // type is specified in the request. This avoids the need for clients to open multiple streams when + // requesting different types, as well as the need for new service definitions such as + // [github.com/envoyproxy/go-control-plane/envoy/service/endpoint/v3.EndpointDiscoveryServiceServer]. + Server = discovery.AggregatedDiscoveryServiceServer + // Node is an alias for [core.Node], the client information included in both Delta and SotW requests. + Node = core.Node + + // SotWClient is an alias for the state-of-the-world client type + // [discovery.AggregatedDiscoveryService_StreamAggregatedResourcesClient]. + SotWClient = discovery.AggregatedDiscoveryService_StreamAggregatedResourcesClient + // SotWStream is an alias for the state-of-the-world stream type for the server + // [discovery.AggregatedDiscoveryService_StreamAggregatedResourcesServer]. + SotWStream = discovery.AggregatedDiscoveryService_StreamAggregatedResourcesServer + // SotWDiscoveryRequest is an alias for the state-of-the-world request type + // [discovery.DiscoveryRequest]. + SotWDiscoveryRequest = discovery.DiscoveryRequest + // SotWDiscoveryResponse is an alias for the state-of-the-world response type + // [discovery.DiscoveryResponse]. + SotWDiscoveryResponse = discovery.DiscoveryResponse + + // DeltaClient is an alias for the delta client type + // [discovery.AggregatedDiscoveryService_DeltaAggregatedResourcesClient]. + DeltaClient = discovery.AggregatedDiscoveryService_DeltaAggregatedResourcesClient + // DeltaStream is an alias for the delta (also known as incremental) stream type for the server + // [discovery.AggregatedDiscoveryService_DeltaAggregatedResourcesServer]. + DeltaStream = discovery.AggregatedDiscoveryService_DeltaAggregatedResourcesServer + // DeltaDiscoveryRequest is an alias for the delta request type [discovery.DeltaDiscoveryRequest]. + DeltaDiscoveryRequest = discovery.DeltaDiscoveryRequest + // DeltaDiscoveryResponse is an alias for the delta response type [discovery.DeltaDiscoveryResponse]. + DeltaDiscoveryResponse = discovery.DeltaDiscoveryResponse + + // RawResource is a type alias for the core ADS type [discovery.Resource]. It is "raw" only in + // contrast to the [*Resource] type defined in this package, which preserves the underlying + // resource's type as a generic parameter. + RawResource = discovery.Resource ) + +// NewResource is a convenience method for creating a new [*Resource]. +func NewResource[T proto.Message](name, version string, t T) *Resource[T] { + return &Resource[T]{ + Name: name, + Version: version, + Resource: t, + } +} + +// Resource is the typed equivalent of [RawResource] in that it preserves the underlying resource's +// type at compile time. It defines the same fields as [RawResource] (except for unsupported fields +// such as [ads.RawResource.Aliases]), and can be trivially serialized to a [RawResource] with +// Marshal. It is undefined behavior to modify a [Resource] after creation. +type Resource[T proto.Message] struct { + Name string + Version string + Resource T + Ttl *durationpb.Duration + CacheControl *discovery.Resource_CacheControl + Metadata *core.Metadata + + marshalOnce sync.Once + marshaled *RawResource + marshalErr error +} + +// Marshal returns the serialized version of this Resource. The result is cached, so Marshal can be +// called repeatedly and from multiple goroutines.
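+//
+// For example (an illustrative sketch; new(ads.Endpoint) stands in for a real resource):
+//
+//	res := ads.NewResource("foo", "v1", new(ads.Endpoint))
+//	raw, err := res.Marshal() // serializes the resource and caches the result
+//	raw, err = res.Marshal()  // returns the same cached *RawResource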
+func (r *Resource[T]) Marshal() (*RawResource, error) { + r.marshalOnce.Do(func() { + var out *anypb.Any + out, r.marshalErr = anypb.New(r.Resource) + if r.marshalErr != nil { + // This shouldn't really ever happen, especially when serializing to Any + slog.Error( + "Failed to serialize proto", + "msg", r.Resource, + "type", string(r.Resource.ProtoReflect().Descriptor().FullName()), + "err", r.marshalErr, + ) + return + } + + r.marshaled = &RawResource{ + Name: r.Name, + Version: r.Version, + Resource: out, + Ttl: r.Ttl, + CacheControl: r.CacheControl, + Metadata: r.Metadata, + } + }) + return r.marshaled, r.marshalErr +} + +// TypeURL returns the underlying resource's type URL. +func (r *Resource[T]) TypeURL() string { + var t T + return types.APITypePrefix + string(t.ProtoReflect().Descriptor().FullName()) +} + +// UnmarshalRawResource unmarshals the given RawResource and returns a Resource of the corresponding +// type. Resource.Marshal on the returned Resource will return the given RawResource instead of +// re-serializing the resource. +func UnmarshalRawResource[T proto.Message](raw *RawResource) (*Resource[T], error) { + m, err := raw.Resource.UnmarshalNew() + if err != nil { + return nil, err + } + + r := &Resource[T]{ + Name: raw.Name, + Version: raw.Version, + Resource: m.(T), + Ttl: raw.Ttl, + CacheControl: raw.CacheControl, + Metadata: raw.Metadata, + } + // Set marshaled using marshalOnce, otherwise the sync.Once will not be marked as done and a + // subsequent call to Marshal would re-serialize the resource, overwriting the field. + r.marshalOnce.Do(func() { + r.marshaled = raw + }) + + return r, nil +} + +const ( + // WildcardSubscription is a special resource name that triggers a subscription to all resources of a + // given type. + WildcardSubscription = "*" + // XDSTPScheme is the prefix with which all resource URNs (as defined in the [TP1 proposal]) start. + // + // [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names + XDSTPScheme = "xdstp://" +) + +// A SubscriptionHandler will receive notifications for the cache entries it has subscribed to using +// Cache.Subscribe. Note that it is imperative that implementations be hashable, as they will be +// stored as map keys (unhashable types include slices and functions). +type SubscriptionHandler[T proto.Message] interface { + // Notify is invoked when the given entry is modified. A deletion is denoted with a nil resource. The given time + // parameters provide the time at which the client subscribed to the resource and the time at which the + // modification happened, respectively. Note that if an entry is modified repeatedly at a high rate, Notify will not + // be invoked for all intermediate versions, though it will always *eventually* be invoked with the final version. + Notify(name string, r *Resource[T], metadata SubscriptionMetadata) +} + +// RawSubscriptionHandler is the untyped equivalent of SubscriptionHandler. +type RawSubscriptionHandler interface { + // Notify is the untyped equivalent of SubscriptionHandler.Notify. + Notify(name string, r *RawResource, metadata SubscriptionMetadata) + // ResourceMarshalError is invoked whenever a resource cannot be marshaled. This should be extremely + // rare and requires immediate attention. When a resource cannot be marshaled, the notification will + // be dropped and Notify will not be invoked.
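+ //
+ // An implementation might, for example, log such failures (an illustrative sketch,
+ // loggingHandler being a hypothetical type):
+ //
+ //	func (h *loggingHandler) ResourceMarshalError(name string, resource proto.Message, err error) {
+ //		slog.Error("failed to marshal resource", "name", name, "err", err)
+ //	}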
+ ResourceMarshalError(name string, resource proto.Message, err error) +} + +// SubscriptionMetadata contains metadata about the subscription that triggered the Notify call on +// the [RawSubscriptionHandler] or [SubscriptionHandler]. +type SubscriptionMetadata struct { + // The time at which the resource was subscribed to + SubscribedAt time.Time + // The time at which the resource was modified (can be UnknownModifiedTime if the modification time is unknown) + ModifiedAt time.Time + // The time at which the update to the resource was received by the cache (i.e. when [Cache.Set] was + // called, not strictly when the server actually received the update). If this metadata is for a + // subscription to a resource that does not yet exist, this will be UnknownModifiedTime. + CachedAt time.Time + // The current priority index of the value. Will be 0 unless the backing cache was created with + // [NewPrioritizedCache], [NewPrioritizedAggregateCache] or + // [NewPrioritizedAggregateCachesByClientTypes]. If this metadata is for a subscription to a resource + // that has been deleted (or does not yet exist), Priority will be the last valid priority index + // (because a resource is only considered deleted once it has been deleted from all cache + // sources). For example, if the cache was created like this: + // NewPrioritizedCache(10) + // Then the last valid index is 9, since the slice of cache objects returned is of length 10. + Priority int + // The glob collection this resource belongs to, empty if it does not belong to any collections. + GlobCollectionURL string +} + +// These aliases mirror the constants declared in [github.com/envoyproxy/go-control-plane/pkg/resource/v3] +type ( + Endpoint = endpoint.ClusterLoadAssignment + LbEndpoint = endpoint.LbEndpoint + Cluster = cluster.Cluster + Route = route.RouteConfiguration + ScopedRoute = route.ScopedRouteConfiguration + VirtualHost = route.VirtualHost + Listener = listener.Listener + Secret = tls.Secret + ExtensionConfig = core.TypedExtensionConfig + Runtime = runtime.Runtime +) + +// StreamType is an enum representing the different possible ADS stream types, SotW and Delta. +type StreamType int + +const ( + // UnknownStreamType is the 0-value, unknown stream type. + UnknownStreamType StreamType = iota + // DeltaStreamType is the delta/incremental variant of the ADS protocol. + DeltaStreamType + // SotWStreamType is the state-of-the-world variant of the ADS protocol. + SotWStreamType +) + +var streamTypeStrings = [...]string{"UNKNOWN", "Delta", "SotW"} + +func (t StreamType) String() string { + return streamTypeStrings[t] +} + +// StreamTypes is an array containing the valid [StreamType] values. +var StreamTypes = [...]StreamType{UnknownStreamType, DeltaStreamType, SotWStreamType} + +// LookupStreamTypeByRPCMethod checks whether the given RPC method string (usually acquired from +// [google.golang.org/grpc.StreamServerInfo.FullMethod] in the context of a server stream +// interceptor) is either [SotWStreamType] or [DeltaStreamType]. Returns ([UnknownStreamType], false) +// if it is neither.
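+//
+// For example, in a server stream interceptor (an illustrative sketch):
+//
+//	func logADSStreams(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+//		if streamType, ok := ads.LookupStreamTypeByRPCMethod(info.FullMethod); ok {
+//			slog.Info("new ADS stream", "type", streamType)
+//		}
+//		return handler(srv, ss)
+//	}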
+func LookupStreamTypeByRPCMethod(rpcMethod string) (StreamType, bool) { + switch rpcMethod { + case "/envoy.service.discovery.v3.AggregatedDiscoveryService/StreamAggregatedResources": + return SotWStreamType, true + case "/envoy.service.discovery.v3.AggregatedDiscoveryService/DeltaAggregatedResources": + return DeltaStreamType, true + default: + return UnknownStreamType, false + } +} diff --git a/ads/glob_collection_url.go b/ads/glob_collection_url.go new file mode 100644 index 0000000..d501393 --- /dev/null +++ b/ads/glob_collection_url.go @@ -0,0 +1,162 @@ +package ads + +import ( + "errors" + "net/url" + "strings" + + types "github.com/envoyproxy/go-control-plane/pkg/resource/v3" +) + +// GlobCollectionURL represents the individual elements of a glob collection URL. Please refer to the +// [TP1 Proposal] for additional context on each field. In summary, a glob collection URL has the following format: +// +// xdstp://{Authority}/{ResourceType}/{Path}{?ContextParameters} +// +// [TP1 Proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names +type GlobCollectionURL struct { + // The URL's authority. Optional when URL of form "xdstp:///{ResourceType}/{Path}". + Authority string + // The type of the resources in the collection, without the "type.googleapis.com/" prefix. + ResourceType string + // The collection's path, without the trailing /* + Path string + // Optionally, the context parameters associated with the collection, always sorted by key name. If + // present, starts with "?". + ContextParameters string +} + +func (u GlobCollectionURL) String() string { + var path string + switch u.Path { + case "": + path = WildcardSubscription + case "/": + path = "/" + WildcardSubscription + default: + path = u.Path + "/" + WildcardSubscription + } + + return XDSTPScheme + + u.Authority + "/" + + u.ResourceType + "/" + + path + + u.ContextParameters +} + +// ErrInvalidGlobCollectionURI is always returned by the various glob collection URL parsing +// functions. +var ErrInvalidGlobCollectionURI = errors.New("diderot: invalid glob collection URI") + +// TODO: the functions in this file return non-specific errors to avoid additional allocations during +// cache updates, which can build up and get expensive. However this can be improved by having an +// error for each of the various ways a string can be an invalid glob collection URL. + +// ParseGlobCollectionURL attempts to parse the given name as GlobCollectionURL, returning an error +// if the given name does not represent one. See the [TP1 proposal] for additional context on the +// exact definition of a glob collection. +// +// [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names +func ParseGlobCollectionURL(name, resourceType string) (GlobCollectionURL, error) { + gcURL, err := parseXDSTPURI(name, resourceType) + if err != nil { + return GlobCollectionURL{}, err + } + + var ok bool + gcURL.Path, ok = strings.CutSuffix(gcURL.Path, "/"+WildcardSubscription) + if !ok { + // URLs must end with /* + return GlobCollectionURL{}, ErrInvalidGlobCollectionURI + } + + return gcURL, nil +} + +// ExtractGlobCollectionURLFromResourceURN checks if the given name is a resource URN, and returns +// the corresponding GlobCollectionURL. 
The format of a resource URN is defined in the +// [TP1 proposal], and looks like this: +// +// xdstp://[{authority}]/{resource type}/{id/*}?{context parameters} +// +// For example: +// +// xdstp://some-authority/envoy.config.listener.v3.Listener/foo/bar/baz +// +// In the above example, the URN belongs to this collection: +// +// xdstp://some-authority/envoy.config.listener.v3.Listener/foo/bar/* +// +// Note that in the above example, the URN does _not_ belong to the following collection: +// +// xdstp://some-authority/envoy.config.listener.v3.Listener/foo/* +// +// Glob collections are not recursive, and the {id/*} segment of the URN (after the type) should be +// opaque, and not interpreted any further than the trailing /*. More details on this matter can be +// found [here]. +// +// This function returns an error if the given name is not a resource URN. +// +// [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names +// [here]: https://github.com/cncf/xds/issues/91 +func ExtractGlobCollectionURLFromResourceURN(name, resourceType string) (GlobCollectionURL, error) { + gcURL, err := parseXDSTPURI(name, resourceType) + if err != nil { + return GlobCollectionURL{}, err + } + + lastSlash := strings.LastIndex(gcURL.Path, "/") + if lastSlash == -1 { + // Missing path in URL + return GlobCollectionURL{}, ErrInvalidGlobCollectionURI + } + + if gcURL.Path[lastSlash:] == "/"+WildcardSubscription { + // resource URN cannot end in /* + return GlobCollectionURL{}, ErrInvalidGlobCollectionURI + } + + if lastSlash == 0 { + gcURL.Path = "/" + } else { + gcURL.Path = gcURL.Path[:lastSlash] + } + + return gcURL, nil +} + +func parseXDSTPURI(resourceName, resourceType string) (GlobCollectionURL, error) { + // Skip deserializing the resource name if it doesn't start with the correct scheme + if !strings.HasPrefix(resourceName, XDSTPScheme) { + // doesn't start with xdstp:// + return GlobCollectionURL{}, ErrInvalidGlobCollectionURI + } + + parsedURL, err := url.Parse(resourceName) + if err != nil { + // invalid URL + return GlobCollectionURL{}, ErrInvalidGlobCollectionURI + } + + // Glob collection URLs do not start with the type prefix, so trim it here. + resourceType = strings.TrimPrefix(resourceType, types.APITypePrefix) + + collectionPath, ok := strings.CutPrefix(parsedURL.EscapedPath(), "/"+resourceType+"/") + if !ok { + // should include expected type after authority + return GlobCollectionURL{}, ErrInvalidGlobCollectionURI + } + + u := GlobCollectionURL{ + Authority: parsedURL.Host, + ResourceType: resourceType, + Path: collectionPath, + } + if len(parsedURL.RawQuery) > 0 { + // Using .Query() to parse the query then .Encode() to re-serialize ensures the query parameters are + // in the right sorted order. + u.ContextParameters = "?"
+ parsedURL.Query().Encode() + } + + return u, nil +} diff --git a/ads/glob_collection_url_test.go b/ads/glob_collection_url_test.go new file mode 100644 index 0000000..de928b3 --- /dev/null +++ b/ads/glob_collection_url_test.go @@ -0,0 +1,158 @@ +package ads + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +const ( + resourceType = "google.protobuf.Int64Value" +) + +func testBadURIs(t *testing.T, parser func(string, string) (GlobCollectionURL, error)) { + badURIs := []struct { + name string + resourceName string + }{ + { + name: "empty name", + resourceName: "", + }, + { + name: "invalid prefix", + resourceName: "https://foo/bar", + }, + { + name: "wrong type", + resourceName: "xdstp://auth/some.other.type/foo", + }, + { + name: "empty id", + resourceName: "xdstp://auth/google.protobuf.Int64Value", + }, + { + name: "empty id trailing slash", + resourceName: "xdstp://auth/google.protobuf.Int64Value/", + }, + { + name: "invalid query", + resourceName: "xdstp://auth/google.protobuf.Int64Value/foo?asd", + }, + } + + for _, test := range badURIs { + t.Run(test.name, func(t *testing.T) { + _, err := parser(test.resourceName, resourceType) + require.Error(t, err) + }) + } +} + +func testGoodURIs(t *testing.T, id string, parser func(string, string) (GlobCollectionURL, error)) { + tests := []struct { + name string + resourceName string + expected GlobCollectionURL + expectErr bool + }{ + { + name: "standard", + resourceName: "xdstp://auth/google.protobuf.Int64Value/foo/" + id, + expected: GlobCollectionURL{ + Authority: "auth", + ResourceType: resourceType, + Path: "foo", + ContextParameters: "", + }, + }, + { + name: "empty authority", + resourceName: "xdstp:///google.protobuf.Int64Value/foo/" + id, + expected: GlobCollectionURL{ + Authority: "", + ResourceType: resourceType, + Path: "foo", + ContextParameters: "", + }, + }, + { + name: "nested", + resourceName: "xdstp://auth/google.protobuf.Int64Value/foo/bar/baz/" + id, + expected: GlobCollectionURL{ + Authority: "auth", + ResourceType: resourceType, + Path: "foo/bar/baz", + ContextParameters: "", + }, + }, + { + name: "with query", + resourceName: "xdstp://auth/google.protobuf.Int64Value/foo/" + id + "?asd=123", + expected: GlobCollectionURL{ + Authority: "auth", + ResourceType: resourceType, + Path: "foo", + ContextParameters: "?asd=123", + }, + }, + { + name: "with unsorted query", + resourceName: "xdstp://auth/google.protobuf.Int64Value/foo/" + id + "?b=2&a=1", + expected: GlobCollectionURL{ + Authority: "auth", + ResourceType: resourceType, + Path: "foo", + ContextParameters: "?a=1&b=2", + }, + }, + { + name: "empty query", + resourceName: "xdstp://auth/google.protobuf.Int64Value/foo/" + id + "?", + expected: GlobCollectionURL{ + Authority: "auth", + ResourceType: resourceType, + Path: "foo", + ContextParameters: "", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual, err := parser(test.resourceName, resourceType) + if test.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, test.expected, actual) + } + }) + } +} + +func TestParseGlobCollectionURL(t *testing.T) { + t.Run("bad URIs", func(t *testing.T) { + testBadURIs(t, ParseGlobCollectionURL) + }) + t.Run("good URIs", func(t *testing.T) { + testGoodURIs(t, WildcardSubscription, ParseGlobCollectionURL) + }) + t.Run("rejects URNs", func(t *testing.T) { + _, err := ParseGlobCollectionURL("xdstp:///"+resourceType+"/foo/bar", resourceType) + require.Error(t, err) + }) +} + +func 
TestExtractGlobCollectionURLFromResourceURN(t *testing.T) { + t.Run("bad URIs", func(t *testing.T) { + testBadURIs(t, ExtractGlobCollectionURLFromResourceURN) + }) + t.Run("good URIs", func(t *testing.T) { + testGoodURIs(t, "foo", ExtractGlobCollectionURLFromResourceURN) + }) + t.Run("rejects glob collection URLs", func(t *testing.T) { + _, err := ExtractGlobCollectionURLFromResourceURN("xdstp:///"+resourceType+"/foo/*", resourceType) + require.Error(t, err) + }) +} diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..288af19 --- /dev/null +++ b/cache.go @@ -0,0 +1,377 @@ +package diderot + +import ( + "fmt" + "time" + + "github.com/linkedin/diderot/ads" + internal "github.com/linkedin/diderot/internal/cache" + "google.golang.org/protobuf/proto" +) + +// Cache is the primary type provided by this package. It provides an efficient storage mechanism for +// ads.RawResource objects, and the means to subscribe to them via the SubscriptionHandler interface. +// For example, it can be used to store the set of "envoy.config.listener.v3.Listener" available to +// clients. +type Cache[T proto.Message] interface { + RawCache + // Set stores the given resource in the cache. If the resource name corresponds to a resource URN, it + // will also be stored in the corresponding glob collection (see [TP1 proposal] for additional + // details on the format). See Subscribe for more details on how the resources added by this method + // can be subscribed to. Invoking Set whenever possible is preferred to RawCache.SetRaw, since it can + // return an error if the given resource's type does not match the expected type while Set validates + // at compile time that the given value matches the desired type. A zero [time.Time] can be used to + // represent that the time at which the resource was created or modified is unknown (or ignored). + // + // WARNING: It is imperative that the Resource and the underlying [proto.Message] not be modified + // after insertion! This resource will be read by subscribers to the cache and callers of Get, and + // modifying the resource may at best result in incorrect reads for consumers and at worst panics if + // the consumer is reading a map as it's being modified. When in doubt, callers should pass in a deep + // copy of the resource. Note that the cache takes no responsibility in enforcing this since cloning + // every resource as it is inserted in the cache may incur unexpected and avoidable costs. + // + // [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names + Set(name, version string, t T, modifiedAt time.Time) *ads.Resource[T] + // SetResource is the more verbose equivalent of Set which supports the additional fields in [ads.Resource]. + SetResource(r *ads.Resource[T], modifiedAt time.Time) + // Get fetches the entry, or nil if it's not present and/or has been deleted. + Get(name string) *ads.Resource[T] + // IsSubscribedTo checks whether the given handler is subscribed to the given named entry. + IsSubscribedTo(name string, handler ads.SubscriptionHandler[T]) bool + // Subscribe registers the handler as a subscriber of the given named resource. The handler is always + // immediately called with the current values of the entries selected by this call, even if it was + // already subscribed. + // + // If the name is ads.WildcardSubscription, the handler is registered as a wildcard subscriber. 
This + // means the handler will be subscribed to all existing entries, and be automatically subscribed to + // any new entries until a corresponding call to Unsubscribe is made. + // + // If the name is a glob collection URL, the handler will be subscribed to all entries in the + // collection, along with being automatically subscribed to any new entries. If the collection is + // empty, the handler will receive a deletion notification for the entire collection. This behavior + // is defined in the [TP1 proposal]: + // If no resources are present in the glob collection, the server should reply with a + // DeltaDiscoveryResponse in which the glob collection URL is specified in removed_resources. + // The subscription will be preserved even if the glob collection is empty (or becomes empty) until a + // corresponding call to Unsubscribe is made. + // + // Otherwise, the handler will be subscribed to the resource specified by the given name and receive + // notifications any time the resource changes. If a resource by that name does not exist, the + // handler will immediately receive a deletion notification, but will not be unsubscribed until a + // corresponding call to Unsubscribe is made. See the [spec on deletions] for more details. + // + // Note that there are therefore three ways to subscribe to a given resource: + // 1. The simplest way is to explicitly subscribe to a resource, via its name. Such a subscription + // can only be cancelled with a corresponding call to Unsubscribe. It will not, for example, be + // cancelled by unsubscribing from the wildcard. This is by design, as it allows clients to discover + // resources by emitting a wildcard subscription, finding which resources they are interested in, + // explicitly subscribing to those, then removing the implicit subscriptions to other resources by + // unsubscribing from the wildcard. This is outlined in the [sample xDS flows]. + // 2. If the resource's name is a URN, a subscription to the matching glob collection URL will + // subscribe the given handler to the resource. Similar to the explicit subscription listed in 1., + // unsubscribing from the wildcard will not cancel a glob collection subscription to a resource; only + // a corresponding unsubscription from the collection will cancel it. + // 3. A wildcard subscription will also implicitly create a subscription to the resource. + // + // Note that while the xDS docs are clear on what the behavior should be when a subscription is + // "upgraded" from a wildcard subscription to an explicit subscription, they are not clear as to what + // happens when a subscription is "downgraded". For example, if a client subscribes to a resource "A" + // then subscribes to the wildcard, should an unsubscription from the wildcard cancel the + // subscription to "A"? Similarly, the docs are unclear as to what should happen if a client + // subscribes to the wildcard, then subscribes to resource "A", then unsubscribes from "A". Should + // the original implicit subscription to "A" via the wildcard be honored? To address both of these, + // the cache will preserve all subscriptions that target a specific resource. This means a client that + // subscribed to a resource both via a wildcard and an explicit subscription (regardless of order) will + // only be unsubscribed from that resource once it has both explicitly unsubscribed from the resource and + // unsubscribed from the wildcard (regardless of order).
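+ //
+ // For example (an illustrative sketch, h being some SubscriptionHandler[T]):
+ //
+ //	c.Subscribe(ads.WildcardSubscription, h) // h implicitly subscribed to all entries
+ //	c.Subscribe("foo", h)                    // h now also explicitly subscribed to "foo"
+ //	c.Unsubscribe(ads.WildcardSubscription, h)
+ //	// h is still subscribed to "foo"; only the following cancels that subscription:
+ //	c.Unsubscribe("foo", h)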
+ // + // It is unsafe for multiple goroutines to invoke Subscribe and/or Unsubscribe with the same + // SubscriptionHandler, and doing so will result in undefined behavior. + // + // [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#glob + // [sample xDS flows]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol#how-the-client-specifies-what-resources-to-return + // [spec on deletions]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol#id2 + Subscribe(name string, handler ads.SubscriptionHandler[T]) + // Unsubscribe removes the given handler from the named entry's list of subscribers. + // + // If the given name is ads.WildcardSubscription, the handler is unsubscribed from all entries it did + // not explicitly subscribe to (see definition of explicit subscription in Subscribe). + // + // If the given name is a glob collection URL, it is unsubscribed from the collection, unsubscribing + // it from all matching entries. + // + // This is a no-op if the resource does not exist or the handler was not subscribed to it. + Unsubscribe(name string, handler ads.SubscriptionHandler[T]) +} + +// RawCache is a subset of the [Cache] interface and provides a number of methods to interact with +// the [Cache] without needing to know the underlying resource type at compile time. All RawCache +// implementations *must* also implement [Cache] for the underlying resource type. +type RawCache interface { + // Type returns the corresponding [Type] for this cache. + Type() Type + // EntryNames invokes the given function for all the current entry names in the cache. If the function returns + // false, the iteration stops. The entries are iterated over in random order. + EntryNames(f func(name string) bool) + // GetRaw is the untyped equivalent of Cache.Get. There are uses for this method, but the preferred + // way is to use Cache.Get because this function incurs the cost of marshaling the resource. Returns + // an error if the resource cannot be marshaled. + GetRaw(name string) (*ads.RawResource, error) + // SetRaw is the untyped equivalent of Cache.Set. There are uses for this method, but the preferred + // way is to use Cache.Set since it offers a typed API instead of the untyped ads.RawResource parameter. + // Subscribers will be notified of the new version of this resource. See Cache.Set for additional + // details on how the resources are stored. Returns an error if the given resource's type URL does + // not match the expected type URL, or the resource cannot be unmarshaled. + SetRaw(r *ads.RawResource, modifiedAt time.Time) error + // Clear clears the entry (if present) and notifies all subscribers that the entry has been deleted. + // A zero [time.Time] can be used to represent that the time at which the resource was cleared is + // unknown (or ignored). For example, when watching a directory, the filesystem does not keep track + // of when the file was deleted. + Clear(name string, clearedAt time.Time) +} + +// NewCache returns a simple Cache with only 1 priority (see NewPrioritizedCache). +func NewCache[T proto.Message]() Cache[T] { + return NewPrioritizedCache[T](1)[0] +} + +// NewPrioritizedCache creates a series of Cache accessors that all point to the same underlying +// cache, but have different "priorities". The Cache object that appears first in the returned slice +// has the highest priority, with every subsequent Cache having correspondingly lower priority. If +// the same resource is provided by two Caches, the resource defined by the Cache with the highest +// priority will be provided to subscribers and returned by Cache.Get. Conversely, if a Cache with a +// high priority clears a resource, the underlying cache will fall back to lower priority definitions +// if present. A resource is only fully cleared if it is cleared at all priority levels. +// +// Concretely, this feature is intended to be used when a resource definition can come from multiple +// sources. For example, if resource definitions are being migrated from one source to another, it +// would be sane to always use the new source if it is present, otherwise fall back to the old +// source. This would be as opposed to simply picking whichever source defined the resource most +// recently, as it would mean the resource definition is nondeterministic.
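+//
+// A minimal sketch of the fallback behavior (oldEndpoint and newEndpoint being placeholder
+// resources):
+//
+//	caches := diderot.NewPrioritizedCache[*ads.Endpoint](2)
+//	newSource, oldSource := caches[0], caches[1]
+//	oldSource.Set("foo", "v1", oldEndpoint, time.Time{}) // subscribers see oldEndpoint
+//	newSource.Set("foo", "v2", newEndpoint, time.Time{}) // subscribers now see newEndpoint
+//	newSource.Clear("foo", time.Time{})                  // falls back to oldEndpoint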
+func NewPrioritizedCache[T proto.Message](prioritySlots int) []Cache[T] { + c := newCache[T](prioritySlots) + caches := make([]Cache[T], prioritySlots) + for i := range caches { + caches[i] = newCacheWithPriority[T](c, internal.Priority(i)) + } + + return caches +} + +func newCache[T proto.Message](prioritySlots int) *cache[T] { + ref := TypeOf[T]() + return &cache[T]{ + typeReference: ref, + trimmedTypeURL: ref.TrimmedURL(), + prioritySlots: prioritySlots, + } +} + +// A cache implements a data structure that allows storing and subscribing to xDS objects. It's expected that there will +// be tens of thousands of cache readers each subscribed to hundreds of resources, making the cache particularly +// read-heavy. Under such load, it is preferable to have more work occur on each write to alleviate the work that needs +// to be done on each read since one write can, at worst, be multiplied into hundreds of thousands of reads. As such, +// the cache is based on a subscription model (via cache.Subscribe) that minimizes reader overhead. Instead of reading +// from the backing map every time, subscribers subscribe directly to updates on the backing value. Writers call +// cache.Set to notify all active subscribers. +type cache[T proto.Message] struct { + // This is the type of each resource in this cache. Set and SetResource guarantee that all insertions + // in this cache satisfy this invariant. + typeReference TypeReference[T] + // The typeURL of the resources in this cache, without the leading "type.googleapis.com/". Used for + // resource URNs which do not include this prefix. + trimmedTypeURL string + // This resourceMap maps the resource's name to its corresponding WatchableValue. + resources internal.ResourceMap[string, *internal.WatchableValue[T]] + // The number of slots watchableValue instances should be created with (see NewPrioritizedCache for + // details on the cache priority). + prioritySlots int + // The set of wildcard subscribers that should be automatically subscribed to any new entries. + wildcardSubscribers internal.SubscriberSet[T] + // This secondary data structure is updated any time a resource that belongs to a glob collection is + // added or removed from the map. Resources belong to glob collections if their name is an xdstp URN + // (see ExtractGlobCollectionURLFromResourceURN).
+ globCollections internal.GlobCollectionsMap[T] +} + +func (c *cache[T]) Type() Type { + return c.typeReference +} + +func (c *cache[T]) IsSubscribedTo(name string, handler ads.SubscriptionHandler[T]) (subscribed bool) { + if c.wildcardSubscribers.IsSubscribed(handler) { + return true + } + + if gcURL, err := ads.ParseGlobCollectionURL(name, c.trimmedTypeURL); err == nil { + return c.globCollections.IsSubscribed(gcURL, handler) + } + + c.resources.ComputeIfPresent(name, func(name string, value *internal.WatchableValue[T]) { + subscribed = value.IsSubscribed(handler) + }) + + return subscribed +} + +func (c *cache[T]) Subscribe(name string, handler ads.SubscriptionHandler[T]) { + if name == ads.WildcardSubscription { + subscribedAt, version := c.wildcardSubscribers.Subscribe(handler) + c.EntryNames(func(name string) bool { + // Cannot call c.Subscribe here because it always creates a backing watchableValue if it does not + // already exist. For wildcard subscriptions, if the entry doesn't exist (or in this case has been + // deleted), a subscription isn't necessary. If the entry reappears, it will be automatically + // subscribed to. + c.resources.ComputeIfPresent(name, func(name string, value *internal.WatchableValue[T]) { + value.NotifyHandlerAfterSubscription(handler, internal.WildcardSubscription, subscribedAt, version) + }) + return true + }) + } else if gcURL, err := ads.ParseGlobCollectionURL(name, c.trimmedTypeURL); err == nil { + c.globCollections.Subscribe(gcURL, handler) + } else { + c.createOrModifyEntry(name, func(name string, value *internal.WatchableValue[T]) { + value.Subscribe(handler) + }) + } +} + +// createOrModifyEntry executes the given function on the value of that name after ensuring that it exists in the map. +func (c *cache[T]) createOrModifyEntry(name string, f func(name string, value *internal.WatchableValue[T])) { + c.resources.Compute( + name, + func(name string) *internal.WatchableValue[T] { + v := internal.NewValue[T](name, c.prioritySlots) + v.SubscriberSets[internal.WildcardSubscription] = &c.wildcardSubscribers + + if gcURL, err := ads.ExtractGlobCollectionURLFromResourceURN(name, c.trimmedTypeURL); err == nil { + c.globCollections.PutValueInCollection(gcURL, v) + } + + return v + }, + f, + ) +} + +// deleteEntryIfNilAndNoSubscribers attempts to delete the entry of that name from the map, if it exists. If the entry +// exists, it grabs the write lock, deletes the entry from the map then closes the watchableValue.newValue channel, +// signaling to the notification goroutine that this entry will not be updated anymore. +func (c *cache[T]) deleteEntryIfNilAndNoSubscribers(name string) { + c.resources.DeleteIf(name, func(name string, value *internal.WatchableValue[T]) bool { + hasNoExplicitSubscribers := value.SubscriberSets[internal.ExplicitSubscription].Size() == 0 + if value.Read() == nil && hasNoExplicitSubscribers { + if gcURL, err := ads.ExtractGlobCollectionURLFromResourceURN(name, c.trimmedTypeURL); err == nil { + c.globCollections.RemoveValueFromCollection(gcURL, value) + } + return true + } + // It's possible that between releasing the read lock and acquiring the write lock, the entry was either + // resubscribed to or set to a non-nil value, in which case it is no longer eligible for deletion. + return false + }) +} + +// unsubscribe implements actually unsubscribing the given handler from the value of that name (if it exists). 
If the value ends up with no +// data and no explicit subscribers, the entry is deleted from the map altogether. +func (c *cache[T]) unsubscribe(name string, handler ads.SubscriptionHandler[T]) { + var shouldDelete bool + c.resources.ComputeIfPresent(name, func(name string, value *internal.WatchableValue[T]) { + hasNoExplicitSubscribers := value.Unsubscribe(handler) + shouldDelete = hasNoExplicitSubscribers && value.Read() == nil + }) + if shouldDelete { + c.deleteEntryIfNilAndNoSubscribers(name) + } +} + +func (c *cache[T]) Unsubscribe(name string, handler ads.SubscriptionHandler[T]) { + if name == ads.WildcardSubscription { + c.wildcardSubscribers.Unsubscribe(handler) + } else if gcURL, err := ads.ParseGlobCollectionURL(name, c.trimmedTypeURL); err == nil { + c.globCollections.Unsubscribe(gcURL, handler) + } else { + c.unsubscribe(name, handler) + } +} + +func (c *cache[T]) Get(name string) (r *ads.Resource[T]) { + c.resources.ComputeIfPresent(name, func(name string, value *internal.WatchableValue[T]) { + r = value.Read() + }) + return r +} + +func (c *cache[T]) GetRaw(name string) (*ads.RawResource, error) { + r := c.Get(name) + if r == nil { + return nil, nil + } + return r.Marshal() +} + +func (c *cache[T]) EntryNames(f func(name string) bool) { + c.resources.Keys(f) +} + +var _ Cache[proto.Message] = (*cacheWithPriority[proto.Message])(nil) + +func newCacheWithPriority[T proto.Message](c *cache[T], p internal.Priority) *cacheWithPriority[T] { + return &cacheWithPriority[T]{cache: c, p: p} +} + +// cacheWithPriority holds a reference to an underlying cache along with a specific priority index. +// It is the only implementation of Cache. Whenever the Set, SetResource or Clear methods are +// invoked, it invokes the respective watchableValue.set or watchableValue.clear methods with the +// priority index. This way, each source gets its own Cache reference that has a built-in priority +// index, instead of being required to explicitly specify the index, which is error-prone and could +// lead to unexpected behavior. +type cacheWithPriority[T proto.Message] struct { + *cache[T] + p internal.Priority +} + +func (c *cacheWithPriority[T]) Clear(name string, clearedAt time.Time) { + var shouldDelete bool + c.resources.ComputeIfPresent(name, func(name string, value *internal.WatchableValue[T]) { + shouldDelete = value.Clear(c.p, clearedAt) && value.SubscriberSets[internal.ExplicitSubscription].Size() == 0 + }) + if shouldDelete { + c.deleteEntryIfNilAndNoSubscribers(name) + } +} + +func (c *cacheWithPriority[T]) Set(name, version string, t T, modifiedAt time.Time) *ads.Resource[T] { + r := &ads.Resource[T]{ + Name: name, + Version: version, + Resource: t, + } + c.SetResource(r, modifiedAt) + return r +} + +func (c *cacheWithPriority[T]) SetResource(r *ads.Resource[T], modifiedAt time.Time) { + c.createOrModifyEntry(r.Name, func(name string, value *internal.WatchableValue[T]) { + value.Set(c.p, r, modifiedAt) + }) +} + +func (c *cacheWithPriority[T]) SetRaw(raw *ads.RawResource, modifiedAt time.Time) error { + // Ensure that the given resource's type URL is correct.
+ if u := raw.GetResource().GetTypeUrl(); u != c.typeReference.URL() { + return fmt.Errorf("diderot: invalid type URL, expected %q got %q", c.typeReference, u) + } + + r, err := ads.UnmarshalRawResource[T](raw) + if err != nil { + return err + } + + c.SetResource(r, modifiedAt) + + return nil +} diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..e9a32d8 --- /dev/null +++ b/cache_test.go @@ -0,0 +1,794 @@ +package diderot_test + +import ( + "fmt" + "maps" + "sort" + "strconv" + "sync" + "testing" + "time" + + "github.com/linkedin/diderot" + "github.com/linkedin/diderot/ads" + internal "github.com/linkedin/diderot/internal/cache" + "github.com/linkedin/diderot/internal/utils" + "github.com/linkedin/diderot/testutils" + "github.com/stretchr/testify/require" + . "google.golang.org/protobuf/types/known/timestamppb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +func newCache() diderot.Cache[*Timestamp] { + return diderot.NewCache[*Timestamp]() +} + +func newResource(name, version string) *ads.Resource[*Timestamp] { + return ads.NewResource(name, version, Now()) +} + +const ( + name1 = "r1" + name2 = "r2" + name3 = "r3" +) + +var noTime time.Time + +func TestCacheCrud(t *testing.T) { + c := newCache() + + require.Nil(t, c.Get(name1)) + + version := "1" + + r1 := c.Set(name1, version, Now(), noTime) + require.Same(t, r1, c.Get(name1)) + + c.Clear(name1, noTime) + require.Nil(t, c.Get(name1)) + + c.SetResource(r1, noTime) + + checkEntries := func(expected ...string) { + var entries []string + c.EntryNames(func(name string) bool { + entries = append(entries, name) + return true + }) + sort.Strings(entries) + sort.Strings(expected) + require.Equal(t, expected, entries) + } + + checkEntries(name1) + c.Set(name2, version, Now(), noTime) + checkEntries(name1, name2) +} + +func TestCacheSubscribe(t *testing.T) { + c := newCache() + + updates := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) + wildcard := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) + c.Subscribe(ads.WildcardSubscription, wildcard) + + // Subscribe should always feed the initial value of the resource to the stream, even if it does not currently exist + c.Subscribe(name1, updates) + updates.WaitForDelete(t, name1) + select { + case <-updates: + t.Fatalf("Received unexpected update for %q", name1) + case <-time.After(10 * time.Millisecond): + } + + // Test that explicitly re-subscribing to the cache even if the entry doesn't exist delivers the + // notification again. + c.Subscribe(name1, updates) + updates.WaitForDelete(t, name1) + + // conversely, creating the nil entry should _not_ update the wildcard subscriber until the entry is actually + // updated + select { + case <-wildcard: + t.Fatalf("Received update for %q even though entry does not actually exist", name1) + case <-time.After(10 * time.Millisecond): + } + + r2 := c.Set(name2, "2", Now(), noTime) + c.Subscribe(name2, updates) + updates.WaitForUpdate(t, r2) + wildcard.WaitForUpdate(t, r2) + + // Test that explicitly re-subscribing to the cache delivers the resource again. + c.Subscribe(name2, updates) + updates.WaitForUpdate(t, r2) + + r1 := c.Set(name1, "1", Now(), noTime) + updates.WaitForUpdate(t, r1) + wildcard.WaitForUpdate(t, r1) + + c.Clear(name1, noTime) + updates.WaitForDelete(t, name1) + wildcard.WaitForDelete(t, name1) + + // Test that the cache ignores double deletes by re-clearing r1, then updating it. 
We should only see the update + c.Clear(name1, noTime) + + c.Unsubscribe(name1, updates) + + c.Clear(name2, noTime) + // Because the cache updates are happening from this thread and all the methods return only after having pushed the + // updates to all the subscribers, we know that if we receive the update for name2 on the stream then the cache + // correctly unsubscribed the stream from r1. + updates.WaitForDelete(t, name2) + wildcard.WaitForDelete(t, name2) + c.Unsubscribe(name1, updates) + c.Unsubscribe(name2, updates) + + c.EntryNames(func(name string) bool { + t.Fatal("Cache should be empty!") + return true + }) + + r1 = c.Set(name1, "3", Now(), noTime) + wildcard.WaitForUpdate(t, r1) + r2 = c.Set(name2, "4", Now(), noTime) + wildcard.WaitForUpdate(t, r2) + + // Ensure a double wildcard subscription notifies of all existing values (note that this needs to run a separate + // routine since it'll push 2 values onto the channel, which only has a capacity of 1) + go c.Subscribe(ads.WildcardSubscription, wildcard) + for i := 0; i < 2; i++ { + r := <-wildcard + if r.Name == name1 { + require.Same(t, r1, r.Resource) + } else { + require.Same(t, r2, r.Resource) + } + } + + // Explicit subscription to r1 for wildcard + c.Subscribe(name1, wildcard) + // An explicit subscription will always yield the current value, even if the handler has already seen it + wildcard.WaitForUpdate(t, r1) + + // Remove remaining wildcard subscriptions + c.Unsubscribe(ads.WildcardSubscription, wildcard) + r1 = c.Set(name1, "5", Now(), noTime) + // Check that the old subscription was preserved + wildcard.WaitForUpdate(t, r1) + + // but that the handler is not automatically added to new entries + c.IsSubscribedTo(name3, wildcard) + r3 := c.Set(name3, "6", Now(), noTime) + select { + case <-wildcard: + t.Fatalf("Received update for %v even though there was no subscription for it", r3) + case <-time.After(10 * time.Millisecond): + } + // Nor old entries + c.IsSubscribedTo(name2, wildcard) + c.Set(name2, "7", Now(), noTime) + select { + case <-wildcard: + t.Fatalf("Received update for %q even though there was no subscription for it", name2) + case <-time.After(10 * time.Millisecond): + } +} + +func TestWildcardSubscriptionOnNonEmptyCache(t *testing.T) { + testutils.WithTimeout(t, "real entries", 5*time.Second, func(t *testing.T) { + c := newCache() + + r1 := c.Set(name1, "1", Now(), noTime) + r2 := c.Set(name2, "2", Now(), noTime) + + wildcard := make(testutils.ChanSubscriptionHandler[*Timestamp]) + go c.Subscribe(ads.WildcardSubscription, wildcard) + + remaining := utils.NewSet(name1, name2) + for r := range wildcard { + require.True(t, remaining.Remove(r.Name)) + switch r.Name { + case name1: + require.Same(t, r1, r.Resource) + case name2: + require.Same(t, r2, r.Resource) + } + if len(remaining) == 0 { + break + } + } + }) + // This tests that entries that were created by a call to Subscribe but never set (aka "fake" entries) are never + // shown to a wildcard subscriber + testutils.WithTimeout(t, "fake entries", 5*time.Second, func(t *testing.T) { + c := newCache() + + h := testutils.NewSubscriptionHandler[*Timestamp]( + func(name string, r *ads.Resource[*Timestamp], _ ads.SubscriptionMetadata) { + require.Nil(t, r) + }, + ) + c.Subscribe(name1, h) + c.Subscribe(name2, h) + + wildcard := testutils.NewSubscriptionHandler[*Timestamp]( + func(string, *ads.Resource[*Timestamp], ads.SubscriptionMetadata) { + require.Fail(t, "Wildcard handler should not be called") + }, + ) + for i := 0; i < 10; i++ { + 
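+			// Repeated wildcard re-subscriptions must still never surface the "fake" entries: they only
+			// exist to track the pending explicit subscriptions and are invisible to wildcard subscribers.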
c.Subscribe(ads.WildcardSubscription, wildcard) + } + }) + // Tests that once an entry has been deleted, it is not shown to a wildcard subscriber, even if it wildcard + // subscribes again. + testutils.WithTimeout(t, "deleted entries", 5*time.Second, func(t *testing.T) { + c := newCache() + + notified := make(chan struct{}, 1) + wildcard := testutils.NewSubscriptionHandler[*Timestamp]( + func(_ string, r *ads.Resource[*Timestamp], _ ads.SubscriptionMetadata) { + if r != nil { + notified <- struct{}{} + } else { + close(notified) + } + }, + ) + + c.Subscribe(ads.WildcardSubscription, wildcard) + + c.Set(name1, "0", Now(), noTime) + <-notified + c.Clear(name1, noTime) + <-notified + for i := 0; i < 10; i++ { + c.Subscribe(ads.WildcardSubscription, wildcard) + <-notified + } + }) +} + +func TestCachePriority(t *testing.T) { + c := diderot.NewPrioritizedCache[*Timestamp](4) + + handlers := make([]testutils.ChanSubscriptionHandler[*Timestamp], len(c)) + for i := range handlers { + handlers[i] = make(testutils.ChanSubscriptionHandler[*Timestamp], 1) + c[i].Subscribe(name1, handlers[i]) + notification := handlers[i].WaitForDelete(t, name1) + require.Equal(t, len(c)-1, notification.Metadata.Priority) + } + + resources := make([]*ads.Resource[*Timestamp], 5) + for i := range resources { + resources[i] = newResource(name1, "0") + } + + for i := 2; i >= 0; i-- { + c[i].SetResource(resources[i], noTime) + require.Same(t, resources[i], c[i].Get(name1)) + for _, h := range handlers { + notification := h.WaitForUpdate(t, resources[i]) + require.Equal(t, i, notification.Metadata.Priority) + } + } + + // This should be ignored + c[2].SetResource(resources[3], noTime) + for _, c := range c { + require.Same(t, resources[0], c.Get(name1)) + } + + // This should also be ignored + c[2].Clear(name1, noTime) + for _, c := range c { + require.Same(t, resources[0], c.Get(name1)) + } + + // This should bring us back to resources[1] + c[0].Clear(name1, noTime) + for i, h := range handlers { + require.Same(t, resources[1], c[i].Get(name1)) + notification := h.WaitForUpdate(t, resources[1]) + require.Equal(t, 1, notification.Metadata.Priority) + } + + // This should fully delete the resource + c[1].Clear(name1, noTime) + for i, h := range handlers { + require.Nil(t, c[i].Get(name1)) + h.WaitForDelete(t, name1) + } +} + +func TestNotifyMetadata(t *testing.T) { + c := diderot.NewCache[*Timestamp]() + h := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) + + subscribedAtStart := time.Now() + c.Subscribe(name1, h) + subscribedAtEnd := time.Now() + + notification := h.WaitForDelete(t, name1) + // Since go provides no convenient way to mock the system clock, the only way to test the various timestamps + // is to check that they occur within the start and end of the invocation of the function being tested. This is + // preferable to require.WithinDuration, which is inherently flaky as it requires assuming that whatever is being + // measured happens within a predetermined duration (e.g. event "b" happened at most 1ms after "a"). 
+ require.WithinRange(t, notification.Metadata.SubscribedAt, subscribedAtStart, subscribedAtEnd) + require.Equal(t, noTime, notification.Metadata.ModifiedAt) + require.Equal(t, noTime, notification.Metadata.CachedAt) + + for i := 0; i < 10; i++ { + modifiedAt := time.Now() + cachedAtStart := time.Now() + r := c.Set(name1, "0", Now(), modifiedAt) + cachedAtEnd := time.Now() + + notification = h.WaitForUpdate(t, r) + require.Equal(t, notification.Metadata.ModifiedAt, modifiedAt) + require.WithinRange(t, notification.Metadata.CachedAt, cachedAtStart, cachedAtEnd) + } +} + +// TestWatchableValueUpdateCancel tests a very specific edge case where an entry is updated during the subscription +// loop. The loop should abort and not call the remaining subscribers. It should instead restart and run through each +// subscriber with the updated value. +func TestWatchableValueUpdateCancel(t *testing.T) { + c := newCache() + + r1 := newResource(name1, "0") + + var r1Wg, r2Wg sync.WaitGroup + r1Wg.Add(1) + r2Wg.Add(2) + notify := func(name string, r *ads.Resource[*Timestamp], _ ads.SubscriptionMetadata) { + // r is nil during the initial invocation of notify since the resource does not yet exist. + if r == nil { + return + } + + if r == r1 { + c.Set(name1, "1", Now(), noTime) + // if notify is invoked with r1 more than once, this will panic + r1Wg.Done() + } else { + r2Wg.Done() + } + } + + c.Subscribe(name1, testutils.NewSubscriptionHandler(notify)) + c.Subscribe(name1, testutils.NewSubscriptionHandler(notify)) + + c.SetResource(r1, noTime) + + r1Wg.Wait() + r2Wg.Wait() +} + +// TestCacheEntryDeletion specifically checks for entry deletion. There is no explicit way to check whether an entry is +// still in the cache, but it can be implicitly checked by calling EntryNames +func TestCacheEntryDeletion(t *testing.T) { + h := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) + + inCache := func(c diderot.Cache[*Timestamp]) bool { + inCache := false + c.EntryNames(func(name string) bool { + if name == name1 { + inCache = true + return false + } + return true + }) + return inCache + } + checkEntryExists := func(t *testing.T, c diderot.Cache[*Timestamp]) { + require.Truef(t, inCache(c), "%q not in cache!", name1) + } + checkEntryDoesNotExist := func(t *testing.T, c diderot.Cache[*Timestamp]) { + require.Falsef(t, inCache(c), "%q in cache!", name1) + } + + setup := func(t *testing.T) diderot.Cache[*Timestamp] { + c := newCache() + + c.Subscribe(name1, h) + + h.WaitForDelete(t, name1) + + checkEntryExists(t, c) + + return c + } + + // In this test, the value is nil. Unsubscribing from it should delete it + t.Run("delete on last unsub", func(t *testing.T) { + c := setup(t) + wildcard := testutils.NewSubscriptionHandler[*Timestamp]( + func(name string, r *ads.Resource[*Timestamp], _ ads.SubscriptionMetadata) { + t.Fatalf("wildcard handler not expected to ever be called (name=%q, r=%v)", name, r) + }, + ) + c.Subscribe(ads.WildcardSubscription, wildcard) + c.Unsubscribe(name1, h) + // At this point, even though there is a wildcard subscription to the entry, it can still be safely deleted + // because the wildcard entry will be added back automatically. The behavior of the wildcard subscriptions is + // tested elsewhere, so it's not necessary to retest it here. 
+ checkEntryDoesNotExist(t, c) + }) + + // In this test, the value is set, it should not get automatically deleted until cleared + t.Run("clear deletes entry", func(t *testing.T) { + c := setup(t) + // Explicitly set the entry + c.Set(name1, "0", Now(), noTime) + // Remove the subscription from it + c.Unsubscribe(name1, h) + checkEntryExists(t, c) + c.Clear(name1, noTime) + checkEntryDoesNotExist(t, c) + }) +} + +func TestCacheCollections(t *testing.T) { + c := diderot.NewCache[*Timestamp]() + + const prefix = "xdstp:///google.protobuf.Timestamp/" + + h := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) + + c.Subscribe(prefix+"a/*", h) + h.WaitForDelete(t, prefix+"a/*") + + c.Subscribe(prefix+"a/foo", h) + h.WaitForDelete(t, prefix+"a/foo") + + var updates []testutils.ExpectedNotification[*Timestamp] + var deletes []testutils.ExpectedNotification[*Timestamp] + for i := 0; i < 5; i++ { + name, v := prefix+"a/"+strconv.Itoa(i), strconv.Itoa(i) + updates = append(updates, testutils.ExpectUpdate(c.Set(name, v, Now(), noTime))) + deletes = append(deletes, testutils.ExpectDelete[*Timestamp](name)) + } + + h.WaitForNotifications(t, updates...) + + for _, d := range deletes { + c.Clear(d.Name, noTime) + } + + h.WaitForNotifications(t, deletes...) + + h.WaitForDelete(t, prefix+"a/*") +} + +// TestCache raw validates that the various *Raw methods on the cache work as expected. Namely, raw +// subscribers are wrapped in a wrappedHandler, which is then used as the key in the subscriber map. +// It's important to check that subscribing then unsubscribing works as expected. +func TestCacheRaw(t *testing.T) { + c := diderot.RawCache(newCache()) + r := newResource(name1, "42") + + ch := make(chan *ads.RawResource, 1) + h := testutils.NewRawSubscriptionHandler( + t, + func(name string, raw *ads.RawResource, metadata ads.SubscriptionMetadata) { + if raw != nil { + testutils.ProtoEquals(t, testutils.MustMarshal(t, r), raw) + } + ch <- raw + }, + ) + + diderot.Subscribe(c, name1, h) + <-ch + require.NoError(t, c.SetRaw(testutils.MustMarshal(t, r), noTime)) + raw, err := c.GetRaw(r.Name) + require.NoError(t, err) + require.Same(t, testutils.MustMarshal(t, r), raw) + <-ch + c.Clear(name1, noTime) + <-ch + diderot.Unsubscribe(c, name1, h) + select { + case raw := <-ch: + require.Fail(t, "Received unexpected update after unsubscription", raw) + case <-time.After(50 * time.Millisecond): + } +} + +func TestSetThenSubscribe(t *testing.T) { + var called sync.WaitGroup + h := testutils.NewSubscriptionHandler(func(string, *ads.Resource[*wrapperspb.BoolValue], ads.SubscriptionMetadata) { + called.Done() + }) + + c := diderot.NewCache[*wrapperspb.BoolValue]() + called.Add(1) + c.Set("foo", "0", wrapperspb.Bool(true), noTime) + c.Subscribe("foo", h) + time.Sleep(10 * time.Millisecond) + called.Wait() +} + +func TestExplicitAndImplicitSubscription(t *testing.T) { + c := diderot.NewCache[*wrapperspb.Int64Value]() + + names := utils.NewSet[string]() + for i := 0; i < 100; i++ { + names.Add(strconv.Itoa(i)) + } + + var lock sync.Mutex + initialSubscriptionNotificationsRemaining := maps.Clone(names) + allInitialNotificationsReceived := make(chan struct{}) + resourceCreationNotificationsRemaining := maps.Clone(names) + allCreationNotificationsReceived := make(chan struct{}) + + h := testutils.NewSubscriptionHandler( + func(name string, r *ads.Resource[*wrapperspb.Int64Value], metadata ads.SubscriptionMetadata) { + lock.Lock() + defer lock.Unlock() + + if r == nil { + initialSubscriptionNotificationsRemaining.Remove(name) 
+				if len(initialSubscriptionNotificationsRemaining) == 0 {
+					close(allInitialNotificationsReceived)
+				}
+			} else {
+				if !resourceCreationNotificationsRemaining.Remove(name) {
+					t.Fatalf("Received double notifications for %q", name)
+				}
+				if len(resourceCreationNotificationsRemaining) == 0 {
+					close(allCreationNotificationsReceived)
+				}
+			}
+		},
+	)
+
+	// implicit subscription
+	c.Subscribe(ads.WildcardSubscription, h)
+
+	for name := range names {
+		// Explicit subscription
+		c.Subscribe(name, h)
+	}
+
+	<-allInitialNotificationsReceived
+
+	for name := range names {
+		c.Set(name, "0", wrapperspb.Int64(0), noTime)
+	}
+
+	<-allCreationNotificationsReceived
+}
+
+func TestSubscribeToGlobCollection(t *testing.T) {
+	c := diderot.NewCache[*wrapperspb.Int64Value]()
+	h := make(testutils.ChanSubscriptionHandler[*wrapperspb.Int64Value], 2)
+	const (
+		fullGCNamePrefix = "xdstp:///google.protobuf.Int64Value/"
+		foo              = "foo/"
+		collectionName   = fullGCNamePrefix + foo + ads.WildcardSubscription
+	)
+	c.Subscribe(collectionName, h)
+	c.IsSubscribedTo(collectionName, h)
+	h.WaitForDelete(t, collectionName)
+
+	name := func(i int64) string { return fullGCNamePrefix + foo + fmt.Sprint(i) }
+	for i := int64(0); i < 10; i++ {
+		resource := c.Set(name(i), "0", wrapperspb.Int64(i), noTime)
+		h.WaitForUpdate(t, resource)
+	}
+
+	for i := int64(0); i < 10; i++ {
+		c.Clear(name(i), noTime)
+		h.WaitForDelete(t, name(i))
+	}
+
+	h.WaitForDelete(t, collectionName)
+
+	c.Unsubscribe(collectionName, h)
+	c.Set(name(10), "0", wrapperspb.Int64(10), noTime)
+	select {
+	case r := <-h:
+		t.Fatalf("Received update after unsubscription: %+v", r)
+	case <-time.After(100 * time.Millisecond):
+	}
+}
+
+var _ ads.SubscriptionHandler[*wrapperspb.Int64Value] = (*fakeHandler)(nil)
+
+type fakeHandler struct {
+	lock           sync.Mutex
+	resources      map[string]*ads.Resource[*wrapperspb.Int64Value]
+	complete       *sync.WaitGroup
+	allInitialized *sync.WaitGroup
+}
+
+const (
+	initializationVersion = "-1"
+	endResourceVersion    = "0"
+)
+
+func (f *fakeHandler) Notify(name string, r *ads.Resource[*wrapperspb.Int64Value], _ ads.SubscriptionMetadata) {
+	// Locking and writing to a map is a fairly representative Handler action and should provide accurate benchmarking
+	// results
+	f.lock.Lock()
+	defer f.lock.Unlock()
+
+	if r != nil && r.Version == initializationVersion {
+		f.allInitialized.Done()
+		return
+	}
+
+	f.resources[name] = r
+	if r != nil && r.Version == endResourceVersion {
+		f.complete.Done()
+	}
+}
+
+type TB[T testing.TB] interface {
+	testing.TB
+	Run(name string, tb func(T)) bool
+}
+
+// this function benchmarks the cache implementation by doing the following:
+// - It will spin up as many SubscriptionHandler instances as specified by the subscribers parameter to simulate that
+// many clients being connected to the server.
+// - It will then create as many entries as specified by the entries parameter, and subscribe each handler to each
+// entry, simulating worst case scenario load.
+// - Once all the handlers have subscribed to each entry, it spins up a goroutine for each entry and updates each
+// entry repeatedly (driven by b.N when run as a benchmark). This simulates heavy concurrent updates to the cache.
+// - The test will end once each handler has counted down a sync.WaitGroup initialized with the expected number of
+// updates.
+func benchmarkCacheThroughput[T TB[T]](tb T, subscribers, entries int) { + cache := diderot.NewCache[*wrapperspb.Int64Value]() + + resources := make([]*ads.Resource[*wrapperspb.Int64Value], entries) + for i := range resources { + name := strconv.Itoa(i) + cache.Set(name, initializationVersion, new(wrapperspb.Int64Value), noTime) + resources[i] = ads.NewResource(name, "1", new(wrapperspb.Int64Value)) + } + + complete := new(sync.WaitGroup) + + allInitialized := new(sync.WaitGroup) + allInitialized.Add(subscribers * entries) + for i := 0; i < subscribers; i++ { + h := &fakeHandler{ + resources: make(map[string]*ads.Resource[*wrapperspb.Int64Value], entries), + complete: complete, + allInitialized: allInitialized, + } + // Use explicit subscriptions, otherwise the backing cache entry may get deleted, causing significant additional + // allocations. + for _, r := range resources { + cache.Subscribe(r.Name, h) + } + } + allInitialized.Wait() + + tb.Run(fmt.Sprintf("%5d subscribers/%5d entries", subscribers, entries), func(tb T) { + complete.Add(subscribers * entries) + + var n int + if b, ok := any(tb).(*testing.B); ok { + n = b.N + } else { + n = 100 + } + + for _, r := range resources { + go func(r *ads.Resource[*wrapperspb.Int64Value]) { + for i := 0; i < n/2; i++ { + cache.SetResource(r, noTime) + cache.Clear(r.Name, noTime) + } + // This final Set means a different *ads.RawResource is used as the final resource for each b.Run + // invocation. This is critical because the cache will ignore back-to-back updates with the same + // *ads.RawResource, meaning it may not notify the subscribers of the final version if it ignored + // intermediate updates. + cache.Set(r.Name, endResourceVersion, new(wrapperspb.Int64Value), noTime) + }(r) + } + + complete.Wait() + }) +} + +// Benchmark results as of 2023-04-7: +// +// make -B bin/BenchmarkCacheThroughput.profile BENCH_PKG=ads/cache +// Not opening profiles since OPEN_PROFILES is not set +// make[1]: Entering directory `/home/pchesnai/code/linkedin/indis-registry-observer' +// go test \ +// -o bin/BenchmarkCacheThroughput.profile \ +// -count 1 \ +// -benchmem \ +// -bench="^BenchmarkCacheThroughput$" \ +// -cpuprofile profiles/BenchmarkCacheThroughput.cpu \ +// -memprofile profiles/BenchmarkCacheThroughput.mem \ +// -blockprofile profiles/BenchmarkCacheThroughput.block \ +// -benchtime 10s \ +// -timeout 20m \ +// -run "^$" \ +// ./indis-registry-observer/src/ads/cache +// goos: linux +// goarch: amd64 +// cpu: Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +// BenchmarkCacheThroughput/____1_subscribers/____1_entries-8 195447372 61.30 ns/op 0 B/op 0 allocs/op +// BenchmarkCacheThroughput/____1_subscribers/___10_entries-8 63061165 182.8 ns/op 0 B/op 0 allocs/op +// BenchmarkCacheThroughput/____1_subscribers/__100_entries-8 6176012 1893 ns/op 0 B/op 0 allocs/op +// BenchmarkCacheThroughput/____1_subscribers/_1000_entries-8 629365 18486 ns/op 2 B/op 0 allocs/op +// BenchmarkCacheThroughput/___10_subscribers/____1_entries-8 190472704 61.84 ns/op 0 B/op 0 allocs/op +// BenchmarkCacheThroughput/___10_subscribers/___10_entries-8 61527584 191.1 ns/op 0 B/op 0 allocs/op +// BenchmarkCacheThroughput/___10_subscribers/__100_entries-8 6546378 1864 ns/op 0 B/op 0 allocs/op +// BenchmarkCacheThroughput/___10_subscribers/_1000_entries-8 620842 18315 ns/op 4 B/op 0 allocs/op +// BenchmarkCacheThroughput/__100_subscribers/____1_entries-8 194074869 62.25 ns/op 0 B/op 0 allocs/op +// BenchmarkCacheThroughput/__100_subscribers/___10_entries-8 62036193 181.7 ns/op 0 B/op 0 
allocs/op
+// BenchmarkCacheThroughput/__100_subscribers/__100_entries-8 6514137 1955 ns/op 0 B/op 0 allocs/op
+// BenchmarkCacheThroughput/__100_subscribers/_1000_entries-8 567054 18684 ns/op 34 B/op 0 allocs/op
+// BenchmarkCacheThroughput/__100_subscribers/10000_entries-8 56433 202390 ns/op 3333 B/op 5 allocs/op
+// BenchmarkCacheThroughput/_1000_subscribers/____1_entries-8 195721614 61.06 ns/op 0 B/op 0 allocs/op
+// BenchmarkCacheThroughput/_1000_subscribers/___10_entries-8 61467828 193.6 ns/op 0 B/op 0 allocs/op
+// BenchmarkCacheThroughput/_1000_subscribers/__100_entries-8 6236072 1844 ns/op 2 B/op 0 allocs/op
+// BenchmarkCacheThroughput/_1000_subscribers/_1000_entries-8 551036 20126 ns/op 375 B/op 0 allocs/op
+// BenchmarkCacheThroughput/10000_subscribers/__100_entries-8 5894846 1994 ns/op 23 B/op 0 allocs/op
+// PASS
+// ok _/home/pchesnai/code/linkedin/indis-registry-observer/indis-registry-observer/src/ads/cache 297.029s
+// make[1]: Leaving directory `/home/pchesnai/code/linkedin/indis-registry-observer'
+// Command being timed: "make profile_cache"
+// User time (seconds): 1663.37
+// System time (seconds): 5.77
+// Percent of CPU this job got: 560%
+// Elapsed (wall clock) time (h:mm:ss or m:ss): 4:57.57
+// Average shared text size (kbytes): 0
+// Average unshared data size (kbytes): 0
+// Average stack size (kbytes): 0
+// Average total size (kbytes): 0
+// Maximum resident set size (kbytes): 278716
+// Average resident set size (kbytes): 0
+// Major (requiring I/O) page faults: 0
+// Minor (reclaiming a frame) page faults: 379417
+// Voluntary context switches: 239299
+// Involuntary context switches: 197562
+// Swaps: 0
+// File system inputs: 72
+// File system outputs: 312
+// Socket messages sent: 0
+// Socket messages received: 0
+// Signals delivered: 0
+// Page size (bytes): 4096
+// Exit status: 0
+//
+// The final row likely represents the most common use case, where 10k clients are connected to a single machine and
+// have subscribed to 100 resources each. In this setup, this cache can distribute updates to each listener in 2.4ms.
+func BenchmarkCacheThroughput(b *testing.B) {
+	DisableTime(b)
+	increments := []int{1, 10, 100, 1000}
+	for _, subscribers := range increments {
+		for _, entries := range increments {
+			benchmarkCacheThroughput(b, subscribers, entries)
+			if subscribers == 100 && entries == 1000 {
+				benchmarkCacheThroughput(b, 100, 10_000)
+			}
+		}
+	}
+	benchmarkCacheThroughput(b, 10_000, 100)
+}
+
+func TestCacheThroughput(t *testing.T) {
+	benchmarkCacheThroughput(t, 10, 10)
+}
+
+func DisableTime(tb testing.TB) {
+	internal.SetTimeProvider(func() (t time.Time) { return t })
+	tb.Cleanup(func() {
+		internal.SetTimeProvider(time.Now)
+	})
+}
diff --git a/doc.go b/doc.go
new file mode 100644
index 0000000..5a26c3f
--- /dev/null
+++ b/doc.go
@@ -0,0 +1,79 @@
+/*
+Package diderot provides a set of utilities to implement an xDS control plane in Go. Namely, it
+provides two core elements:
+ 1. The [ADSServer], the implementation of both the SotW and Delta ADS stream variants.
+ 2. The [Cache], which is an efficient means to store, retrieve and subscribe to xDS resource definitions.
+
+# ADS Server and Resource Locator
+
+The [ADSServer] is an implementation of the xDS protocol's various features. It implements both the
+Delta and state-of-the-world variants, but abstracts this away completely by only exposing a single
+entry point: the [ResourceLocator].
When the server receives a request (be it Delta or SotW), it
+first checks whether the requested type is supported and whether the request is an ACK (or a NACK),
+then invokes, if necessary, the corresponding subscription methods on the ResourceLocator. The
+locator is simply in charge of invoking Notify on the handler whenever the resource changes, and the
+server will relay that resource update to the client using the corresponding response type. This
+makes it very easy to implement an xDS control plane without needing to worry about the finer
+details of the xDS protocol.
+
+Most ResourceLocator implementations will likely be a series of [Cache] instances for the
+corresponding supported types, which implement the semantics of Subscribe and Resubscribe out of
+the box. However, as long as the semantics are respected, implementations may do as they please. For
+example, a common pattern is listed in the [xDS spec]:
+
+	For Listener and Cluster resource types, there is also a “wildcard” subscription, which is triggered
+	when subscribing to the special name *. In this case, the server should use site-specific business
+	logic to determine the full set of resources that the client is interested in, typically based on
+	the client’s node identification.
+
+Instead of subscribing to a backing [Cache] with the wildcard subscription, said "business logic"
+can be implemented in the [ResourceLocator] and wildcard subscriptions can be transformed into an
+explicit set of resources.
+
+# Cache
+
+This type is the core building block provided by this package. It is effectively a map from
+resource name to [ads.Resource] definitions. It provides a way to subscribe to them in order to be
+notified whenever they change. For example, the [ads.Endpoint] type (aka
+"envoy.config.endpoint.v3.ClusterLoadAssignment") contains the set of IPs that back a specific
+[ads.Cluster] ("envoy.config.cluster.v3.Cluster") and is the final step in the standard LDS -> RDS
+-> CDS -> EDS Envoy flow. The Cache will store the Endpoint instances that back each cluster, and
+Envoy will be able to subscribe to the [ads.Endpoint] resource by providing the correct name when
+subscribing. See [diderot.Cache.Subscribe] for additional details on the subscription model.
+
+It is safe for concurrent use as its concurrency model is per-resource. This means different
+goroutines can modify different resources concurrently, and goroutines attempting to modify the
+same resource will be synchronized.
+
+# Cache Priority
+
+The cache supports a notion of "priority". Concretely, this feature is intended to be used when a
+resource definition can come from multiple sources. For example, if resource definitions are being
+migrated from one source to another, it would be sane to always use the new source if it is present,
+otherwise fall back to the old source. This would be as opposed to simply picking whichever source
+defined the resource most recently, as it would mean the resource definition cannot be relied upon
+to be stable. [NewPrioritizedCache] returns a slice of [Cache] instances of the requested type. The
+instances all point to the same underlying cache, but at different priorities, where instances that
+appear earlier in the slice have a higher priority than those that appear later. If a resource is
+defined at priorities p1 and p2 where p1 is a higher priority than p2, subscribers will see the
+version that was defined at p1. If the resource is cleared at p1, the cache will fall back to the
+definition at p2.
This means that a resource is only ever considered fully deleted if it is cleared
+at all priority levels. The reason a slice of instances is returned rather than adding a priority
+parameter to each function on [Cache] is to avoid complicated configuration or simple bugs where a
+resource is being set at an unintended or invalid priority. Instead, the code path where a source is
+populating the cache simply receives a reference to the cache and starts writing to it. If the
+priority of a source changes in subsequent versions, it can be handled at initialization/startup
+instead of requiring any actual code changes to the source itself.
+
+# xDS TP1 Support
+
+The notion of glob collections defined in the TP1 proposal is supported natively in the [Cache].
+This means that if resource names are [xdstp:// URNs], they will be automatically added to the
+corresponding glob collection, if applicable. These resources are still available for subscription
+by their full URN, but will also be available for subscription by subscribing to the parent glob
+collection. More details are available at [diderot.Cache.Subscribe], [ads.ParseGlobCollectionURL] and
+[ads.ExtractGlobCollectionURLFromResourceURN].
+
+[xDS spec]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol#how-the-client-specifies-what-resources-to-return
+[xdstp:// URNs]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names
+*/
+package diderot
diff --git a/examples/quickstart/main.go b/examples/quickstart/main.go
new file mode 100644
index 0000000..fb1091e
--- /dev/null
+++ b/examples/quickstart/main.go
@@ -0,0 +1,117 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"net"
+	"os"
+
+	corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
+	discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
+	"github.com/linkedin/diderot"
+	"github.com/linkedin/diderot/ads"
+	"google.golang.org/grpc"
+	"google.golang.org/protobuf/proto"
+)
+
+func main() {
+	lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", 8080))
+	if err != nil {
+		log.Fatalf("failed to listen: %v", err)
+	}
+
+	var opts []grpc.ServerOption
+	// populate your opts
+	grpcServer := grpc.NewServer(opts...)
+
+	// Use a very simple ResourceLocator that only supports a limited set of types (namely LDS -> RDS -> CDS -> EDS).
+	locator := NewSimpleResourceLocator(ListenerType, RouteType, ClusterType, EndpointType)
+
+	go PopulateCaches(locator)
+
+	hostname, _ := os.Hostname()
+
+	adsServer := diderot.NewADSServer(locator,
+		// Send max 10k responses per second.
+		diderot.WithGlobalResponseRateLimit(10_000),
+		// Send max one response per type per client every 500ms, to not overload clients.
+		diderot.WithGranularResponseRateLimit(2),
+		// Process max 1k requests per second.
+		diderot.WithRequestRateLimit(1000),
+		diderot.WithControlPlane(&corev3.ControlPlane{Identifier: hostname}),
+	)
+	discovery.RegisterAggregatedDiscoveryServiceServer(grpcServer, adsServer)
+
+	grpcServer.Serve(lis)
+}
+
+var (
+	ListenerType = diderot.TypeOf[*ads.Listener]()
+	RouteType    = diderot.TypeOf[*ads.Route]()
+	ClusterType  = diderot.TypeOf[*ads.Cluster]()
+	EndpointType = diderot.TypeOf[*ads.Endpoint]()
+)
+
+// SimpleResourceLocator is a bare-bones [diderot.ResourceLocator] that provides the bare minimum
+// functionality.
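+//
+// A minimal sketch of how it is meant to be wired up, reusing the helpers declared in this file
+// (the resource name and version below are made up for illustration):
+//
+//	locator := NewSimpleResourceLocator(ListenerType, RouteType, ClusterType, EndpointType)
+//	locator.GetListenerCache().Set("my-listener", "v0", new(ads.Listener), time.Now())
+//	adsServer := diderot.NewADSServer(locator)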
+type SimpleResourceLocator map[string]diderot.RawCache + +func (sl SimpleResourceLocator) IsTypeSupported(streamCtx context.Context, typeURL string) bool { + _, ok := sl[typeURL] + return ok +} + +func (sl SimpleResourceLocator) Subscribe( + streamCtx context.Context, + typeURL, resourceName string, + handler ads.RawSubscriptionHandler, +) (unsubscribe func()) { + c := sl[typeURL] + diderot.Subscribe(c, resourceName, handler) + return func() { + diderot.Unsubscribe(c, resourceName, handler) + } +} + +func (sl SimpleResourceLocator) Resubscribe( + streamCtx context.Context, + typeURL, resourceName string, + handler ads.RawSubscriptionHandler, +) { + diderot.Subscribe(sl[typeURL], resourceName, handler) +} + +// getCache extracts a typed [diderot.Cache] from the given [SimpleResourceLocator]. +func getCache[T proto.Message](sl SimpleResourceLocator) diderot.Cache[T] { + return sl[diderot.TypeOf[T]().URL()].(diderot.Cache[T]) +} + +func (sl SimpleResourceLocator) GetListenerCache() diderot.Cache[*ads.Listener] { + return getCache[*ads.Listener](sl) +} + +func (sl SimpleResourceLocator) GetRouteCache() diderot.Cache[*ads.Route] { + return getCache[*ads.Route](sl) +} + +func (sl SimpleResourceLocator) GetClusterCache() diderot.Cache[*ads.Cluster] { + return getCache[*ads.Cluster](sl) +} + +func (sl SimpleResourceLocator) GetEndpointCache() diderot.Cache[*ads.Endpoint] { + return getCache[*ads.Endpoint](sl) +} + +func NewSimpleResourceLocator(types ...diderot.Type) SimpleResourceLocator { + sl := make(SimpleResourceLocator) + for _, t := range types { + sl[t.URL()] = t.NewCache() + } + return sl +} + +func PopulateCaches(locator SimpleResourceLocator) { + // this is where the business logic of populating the caches should happen. For example, you can read + // the resource definitions from disk, listen to ZK, etc... 
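+	//
+	// For illustration only (the name, version and listener variable below are hypothetical), a
+	// file-based source could write each parsed definition through the typed view of the cache, and
+	// clear it when the corresponding file disappears:
+	//
+	//	locator.GetListenerCache().Set("listener-from-disk", "v1", listener, time.Now())
+	//	locator.GetListenerCache().Clear("listener-from-disk", time.Now())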
+} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9ed78ba --- /dev/null +++ b/go.mod @@ -0,0 +1,36 @@ +module github.com/linkedin/diderot + +go 1.23.0 + +replace google.golang.org/grpc => github.com/PapaCharlie/grpc-go v0.0.0-20240701182450-c5ed95110455 + +require ( + github.com/envoyproxy/go-control-plane v0.12.0 + github.com/google/go-cmp v0.6.0 + github.com/stretchr/testify v1.9.0 + golang.org/x/time v0.5.0 + google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 + google.golang.org/grpc v1.65.0 + google.golang.org/protobuf v1.34.2 +) + +require ( + cel.dev/expr v0.15.0 // indirect + cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + golang.org/x/net v0.27.0 // indirect + golang.org/x/oauth2 v0.20.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.22.0 // indirect + golang.org/x/text v0.16.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240709173604-40e1e62336c5 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..55d1ad3 --- /dev/null +++ b/go.sum @@ -0,0 +1,58 @@ +cel.dev/expr v0.15.0 h1:O1jzfJCQBfL5BFoYktaxwIhuttaQPsVWerH9/EEKx0w= +cel.dev/expr v0.15.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/PapaCharlie/grpc-go v0.0.0-20240701182450-c5ed95110455 h1:V+ucq9BdvIZnpcKmFsE9dFuJQbfneu0BIZ+EvTcfy4o= +github.com/PapaCharlie/grpc-go v0.0.0-20240701182450-c5ed95110455/go.mod h1:Ie8GXJUM8iOXNdBVTGsk658lCVKi+/3EJ4zBhiGj4Yk= +github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= +github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw= +github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI= +github.com/envoyproxy/go-control-plane v0.12.0/go.mod h1:ZBTaoJ23lqITozF0M6G4/IragXCQKCnYbmlmtHvwRG0= +github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= +github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew= +github.com/golang/protobuf v1.5.4 
h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo= +golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +google.golang.org/genproto/googleapis/api v0.0.0-20240709173604-40e1e62336c5 h1:a/Z0jgw03aJ2rQnp5PlPpznJqJft0HyvyrcUcxgzPwY= +google.golang.org/genproto/googleapis/api v0.0.0-20240709173604-40e1e62336c5/go.mod h1:mw8MG/Qz5wfgYr6VqVCiZcHe/GJEfI+oGGDCohaVgB0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 h1:SbSDUWW1PAO24TNpLdeheoYPd7kllICcLU52x6eD4kQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cache/glob_collection.go b/internal/cache/glob_collection.go new file mode 100644 index 
0000000..86db12d --- /dev/null +++ b/internal/cache/glob_collection.go @@ -0,0 +1,119 @@ +package internal + +import ( + "sync" + "time" + + "github.com/linkedin/diderot/ads" + "github.com/linkedin/diderot/internal/utils" + "google.golang.org/protobuf/proto" +) + +// A globCollection is used to track all the resources in the collection. +type globCollection[T proto.Message] struct { + // The URL that corresponds to this collection, represented as the raw string rather than a + // GlobCollectionURL to avoid repeated redundant calls to GlobCollectionURL.String. + url string + + // Protects subscribers and values. + subscribersAndValuesLock sync.RWMutex + // The current subscribers to this collection. + subscribers SubscriberSet[T] + // The set of values in the collection, used by new subscribers to subscribe to all values. + values utils.Set[*WatchableValue[T]] + + // Protects nonNilValueNames + nonNilValueNamesLock sync.Mutex + // The set of all non-nil resource names in this collection. Used to track whether a collection is + // empty. Note that a collection can be empty even if values is non-empty since values that are + // explicitly subscribed to are kept in the collection/cache to track the subscription in case the + // value returns. + nonNilValueNames utils.Set[string] +} + +func (g *globCollection[T]) hasNoValuesOrSubscribersNoLock() bool { + return len(g.values) == 0 && g.subscribers.Size() == 0 +} + +// hasNoValuesOrSubscribers returns true if the collection is empty and has no subscribers. +func (g *globCollection[T]) hasNoValuesOrSubscribers() bool { + g.subscribersAndValuesLock.RLock() + defer g.subscribersAndValuesLock.RUnlock() + + return g.hasNoValuesOrSubscribersNoLock() +} + +// isSubscribed checks if the given handler is already subscribed to the collection. +func (g *globCollection[T]) isSubscribed(handler ads.SubscriptionHandler[T]) bool { + g.subscribersAndValuesLock.Lock() + defer g.subscribersAndValuesLock.Unlock() + + return g.subscribers.IsSubscribed(handler) +} + +// subscribe adds the given handler as a subscriber to the collection, and iterates through all the +// values in the collection, notifying the handler for each value. If the collection is empty, the +// handler will be notified that the resource is deleted. +func (g *globCollection[T]) subscribe(handler ads.SubscriptionHandler[T]) { + g.subscribersAndValuesLock.Lock() + defer g.subscribersAndValuesLock.Unlock() + + subscribedAt, version := g.subscribers.Subscribe(handler) + + if len(g.nonNilValueNames) == 0 { + handler.Notify(g.url, nil, ads.SubscriptionMetadata{ + SubscribedAt: subscribedAt, + ModifiedAt: time.Time{}, + CachedAt: time.Time{}, + }) + } else { + for v := range g.values { + v.NotifyHandlerAfterSubscription(handler, GlobSubscription, subscribedAt, version) + } + } +} + +// unsubscribe unsubscribes the given handler from the collection. Returns true if the collection has +// no subscribers and is empty. +func (g *globCollection[T]) unsubscribe(handler ads.SubscriptionHandler[T]) bool { + g.subscribersAndValuesLock.Lock() + defer g.subscribersAndValuesLock.Unlock() + + g.subscribers.Unsubscribe(handler) + return g.hasNoValuesOrSubscribersNoLock() +} + +// resourceSet notifies the collection that the given resource has been created. +func (g *globCollection[T]) resourceSet(name string) { + g.nonNilValueNamesLock.Lock() + defer g.nonNilValueNamesLock.Unlock() + + g.nonNilValueNames.Add(name) +} + +// resourceCleared notifies the collection that the given resource has been cleared. 
If there are no +// remaining non-nil values in the collection (or no values at all), the subscribers are all notified +// that the collection has been deleted. +func (g *globCollection[T]) resourceCleared(name string) { + g.nonNilValueNamesLock.Lock() + defer g.nonNilValueNamesLock.Unlock() + + g.nonNilValueNames.Remove(name) + + if len(g.nonNilValueNames) == 0 { + g.subscribersAndValuesLock.Lock() + defer g.subscribersAndValuesLock.Unlock() + + deletedAt := time.Now() + + subscribers, _ := g.subscribers.Iterator() + for handler, subscribedAt := range subscribers { + handler.Notify(g.url, nil, ads.SubscriptionMetadata{ + SubscribedAt: subscribedAt, + ModifiedAt: deletedAt, + CachedAt: deletedAt, + GlobCollectionURL: g.url, + }) + } + } +} diff --git a/internal/cache/glob_collections_map.go b/internal/cache/glob_collections_map.go new file mode 100644 index 0000000..979ee8f --- /dev/null +++ b/internal/cache/glob_collections_map.go @@ -0,0 +1,111 @@ +package internal + +import ( + "log/slog" + + "github.com/linkedin/diderot/ads" + "github.com/linkedin/diderot/internal/utils" + "google.golang.org/protobuf/proto" +) + +// GlobCollectionsMap used to map individual GlobCollectionURL to their corresponding globCollection. +// This uses a ResourceMap under the hood because it has similar semantics to cache entries: +// 1. A globCollection is created lazily, either when an entry for that collection is created, or a +// subscription to that collection is made. +// 2. A globCollection is only deleted once all subscribers have unsubscribed and the collection is +// empty. Crucially, a collection can be empty but will remain in the cache as long as some +// subscribers remain subscribed. +type GlobCollectionsMap[T proto.Message] struct { + collections ResourceMap[ads.GlobCollectionURL, *globCollection[T]] +} + +// createOrModifyCollection gets or creates the globCollection for the given GlobCollectionURL, and +// executes the given function on it. +func (gcm *GlobCollectionsMap[T]) createOrModifyCollection( + gcURL ads.GlobCollectionURL, + f func(gcURL ads.GlobCollectionURL, collection *globCollection[T]), +) { + gcm.collections.Compute( + gcURL, + func(gcURL ads.GlobCollectionURL) *globCollection[T] { + gc := &globCollection[T]{ + url: gcURL.String(), + values: make(utils.Set[*WatchableValue[T]]), + nonNilValueNames: make(utils.Set[string]), + } + slog.Debug("Created collection", "url", gcURL) + return gc + }, + f, + ) +} + +// PutValueInCollection creates the glob collection if it was not already created, and puts the given +// value in it. +func (gcm *GlobCollectionsMap[T]) PutValueInCollection(gcURL ads.GlobCollectionURL, value *WatchableValue[T]) { + gcm.createOrModifyCollection(gcURL, func(gcURL ads.GlobCollectionURL, collection *globCollection[T]) { + collection.subscribersAndValuesLock.Lock() + defer collection.subscribersAndValuesLock.Unlock() + + value.globCollection = collection + collection.values.Add(value) + value.SubscriberSets[GlobSubscription] = &collection.subscribers + }) +} + +// RemoveValueFromCollection removes the given value from the collection. If the collection becomes +// empty as a result, it is removed from the map. 
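+// (The deletion deliberately happens outside the ComputeIfPresent callback below:
+// deleteCollectionIfEmpty goes through ResourceMap.DeleteIf, which acquires the entry's write lock,
+// and that lock cannot be acquired while the callback is still holding the read lock.)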
+func (gcm *GlobCollectionsMap[T]) RemoveValueFromCollection(gcURL ads.GlobCollectionURL, value *WatchableValue[T]) { + var isEmpty bool + gcm.collections.ComputeIfPresent(gcURL, func(gcURL ads.GlobCollectionURL, collection *globCollection[T]) { + collection.subscribersAndValuesLock.Lock() + defer collection.subscribersAndValuesLock.Unlock() + + collection.values.Remove(value) + + isEmpty = collection.hasNoValuesOrSubscribersNoLock() + }) + if isEmpty { + gcm.deleteCollectionIfEmpty(gcURL) + } +} + +// Subscribe creates or gets the corresponding collection for the given URL using +// createOrModifyCollection, then invokes globCollection.subscribe with the given handler. +func (gcm *GlobCollectionsMap[T]) Subscribe(gcURL ads.GlobCollectionURL, handler ads.SubscriptionHandler[T]) { + gcm.createOrModifyCollection(gcURL, func(_ ads.GlobCollectionURL, collection *globCollection[T]) { + collection.subscribe(handler) + }) +} + +// Unsubscribe invokes globCollection.unsubscribe on the collection for the given URL, if it exists. +// If, as a result, the collection becomes empty, it invokes deleteCollectionIfEmpty. +func (gcm *GlobCollectionsMap[T]) Unsubscribe(gcURL ads.GlobCollectionURL, handler ads.SubscriptionHandler[T]) { + var isEmpty bool + gcm.collections.ComputeIfPresent(gcURL, func(_ ads.GlobCollectionURL, collection *globCollection[T]) { + isEmpty = collection.unsubscribe(handler) + }) + if isEmpty { + gcm.deleteCollectionIfEmpty(gcURL) + } +} + +// deleteCollectionIfEmpty attempts to completely remove the collection from the map, if and only if +// there are no more subscribers and the collection is empty. +func (gcm *GlobCollectionsMap[T]) deleteCollectionIfEmpty(gcURL ads.GlobCollectionURL) { + gcm.collections.DeleteIf(gcURL, func(_ ads.GlobCollectionURL, collection *globCollection[T]) bool { + empty := collection.hasNoValuesOrSubscribers() + if empty { + slog.Debug("Deleting collection", "url", gcURL) + } + return empty + }) +} + +// IsSubscribed checks if the given handler is subscribed to the collection. +func (gcm *GlobCollectionsMap[T]) IsSubscribed(gcURL ads.GlobCollectionURL, handler ads.SubscriptionHandler[T]) (subscribed bool) { + gcm.collections.ComputeIfPresent(gcURL, func(_ ads.GlobCollectionURL, collection *globCollection[T]) { + subscribed = collection.isSubscribed(handler) + }) + return subscribed +} diff --git a/internal/cache/resource_map.go b/internal/cache/resource_map.go new file mode 100644 index 0000000..2f491a4 --- /dev/null +++ b/internal/cache/resource_map.go @@ -0,0 +1,153 @@ +package internal + +import ( + "sync" +) + +// The resourceMapEntry is the value type in a ResourceMap. Access to an entry is guarded by lock and isDeleted. +// +// The reason ResourceMap does not directly insert the value type into the backing sync.Map is because of the +// following scenario: +// Routine 1 wants to delete an entry because it is not needed anymore, while routine 2 wants to execute an operation on +// the entry. Routine 2 loads the value from the map and runs its operation. In the meantime, routine 1 simply deletes +// the entry from the map, completely voiding routine 2's work. To prevent this, the additional resourceMapEntry layer +// provides a sync.RWMutex and a flag indicating the entry has been deleted from the map (isDeleted). All operations +// acquire the read lock and check isDeleted, bailing if the entry has been deleted between when it was loaded from the +// map and the read lock was acquired. 
Conversely, the write lock on the entry must be held during deletion, and
+// isDeleted must be set once the entry has been deleted and while still holding the write lock. This way, only one
+// routine can successfully delete the entry from the map, but it must first wait for all in-flight read operations to
+// complete (see ResourceMap.DeleteIf for additional info).
+type resourceMapEntry[T any] struct {
+	// The write lock must be held while inserting or removing the entry from the resources map. The read lock must be
+	// held while interacting with value. Before interacting with the value, isDeleted must be checked.
+	lock sync.RWMutex
+	// isDeleted must always be checked before accessing value. If true, this entry should be entirely discarded.
+	// Guarded by lock.
+	isDeleted bool
+	// the actual value
+	value T
+}
+
+// ResourceMap is a typed extension of a sync.Map which allows for fine-grained control over the creation and deletion
+// of entries. It is meant to imitate Java's compute/computeIfPresent/computeIfAbsent methods on ConcurrentHashMap as
+// sync.Map does not natively provide these constructs. It deliberately does not expose bare Get or Put methods as its
+// concurrency model is based on the assumption that access to the backing values must be strictly synchronized.
+// Instead, all operations should be executed through the various compute methods.
+type ResourceMap[K comparable, T any] struct {
+	syncMap sync.Map
+}
+
+// ComputeIfPresent executes the given compute function if the entry is present in the map. There can be multiple
+// executions of ComputeIfPresent in flight at the same time for the same entry.
+func (m *ResourceMap[K, T]) ComputeIfPresent(key K, compute func(key K, value T)) bool {
+	eAny, _ := m.syncMap.Load(key)
+	e, ok := eAny.(*resourceMapEntry[T])
+	if !ok {
+		// The entry didn't exist in the map, do nothing.
+		return false
+	}
+
+	e.lock.RLock()
+	defer e.lock.RUnlock()
+	if e.isDeleted {
+		// In between loading the entry and acquiring the read lock, the entry was deleted, do nothing.
+		return false
+	}
+
+	compute(key, e.value)
+	return true
+}
+
+// Compute ensures that an entry for the given key exists in the map before executing the given compute function. If
+// the entry was already in the map, it has the same semantics as ComputeIfPresent. Otherwise, it uses the given
+// newValue constructor to create a new value before executing the given compute method, and no ComputeIfPresent will
+// run until both the newValue constructor and the compute function complete. This is to prevent
+// subsequent ComputeIfPresent operations from reading a partially initialized value.
+func (m *ResourceMap[K, T]) Compute(
+	key K,
+	newValue func(key K) T,
+	compute func(key K, value T),
+) {
+	// Do an initial check to see if the entry is already present in the map, and if so, apply the function.
+	if m.ComputeIfPresent(key, compute) {
+		return
+	}
+
+	// Otherwise, attempt to create the entry
+	e := new(resourceMapEntry[T])
+	// Inserting the entry while already holding the write lock means another routine that loads this entry from the
+	// map will be forced to wait until it is fully initialized before reading its value.
+	e.lock.Lock()
+	defer e.lock.Unlock()
+
+	// This loop repeatedly attempts to insert the new entry into the map, guaranteeing that it is
+	// eventually inserted.
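+	// Each iteration ends in one of three ways: the new entry is stored (break out and initialize it),
+	// a live entry created by another routine is found (run compute and return), or a deleted entry is
+	// found, in which case the insertion is retried.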
+	for {
+		// Attempt to insert the new resourceMapEntry in the map
+		eAny, didNotStore := m.syncMap.LoadOrStore(key, e)
+		if !didNotStore {
+			// Entry successfully inserted in the map, the rest of the function can initialize it safely as it is
+			// currently holding the write lock.
+			break
+		}
+
+		// Unhappy path: between the original ComputeIfPresent check and now, another routine created the entry.
+		e := eAny.(*resourceMapEntry[T])
+		e.lock.RLock()
+		if !e.isDeleted {
+			// The entry was already present in the map, execute the compute function safely since the read lock is
+			// currently held.
+			compute(key, e.value)
+			// Nothing left to be done, exit the function
+			e.lock.RUnlock()
+			return
+		}
+
+		// Very unhappy path: the entry was deleted between this invocation reading it from the map and
+		// acquiring the read lock! Attempt to insert the new entry again by continuing the loop. Note that this
+		// condition is the reason this is in a loop in the first place. It should also be noted that it is
+		// extremely unlikely for this to occur, and in most instances this branch will never be reached.
+		e.lock.RUnlock()
+	}
+
+	// Initialize the value while holding the write lock, otherwise a call to ComputeIfPresent could read e.value as
+	// nil.
+	e.value = newValue(key)
+	// Technically, at this point, the write lock does not need to be held anymore, only the read lock, since e.value
+	// is initialized. It is however impossible to downgrade the lock from write to read without actually releasing it.
+	// In between releasing the write lock and acquiring the read lock, it is possible for the entry to be deleted!
+	// Hence, the write lock is kept.
+	compute(key, e.value)
+}
+
+// DeleteIf loads the entry from the map if it still exists, then executes the given condition function with the value.
+// If the condition returns true, the entry is deleted from the map, otherwise nothing happens. It is guaranteed that
+// the condition function will only be executed once any in-flight ComputeIfPresent operations for that entry complete.
+// Conversely, once the in-flight operations complete, no new ComputeIfPresent operations will be started for that entry
+// until the condition has been checked. If the entry was deleted, any ComputeIfPresent operations queued for that entry
+// while the condition was being checked will be abandoned. No two executions of DeleteIf can execute in parallel for
+// the same entry.
+func (m *ResourceMap[K, T]) DeleteIf(key K, condition func(key K, value T) bool) {
+	eAny, ok := m.syncMap.Load(key)
+	if !ok {
+		return
+	}
+
+	e := eAny.(*resourceMapEntry[T])
+	e.lock.Lock()
+	defer e.lock.Unlock()
+	if e.isDeleted {
+		return
+	}
+	if condition(key, e.value) {
+		// Mark the entry as deleted while still holding the write lock so that any routine that has
+		// already loaded this entry from the map discards it instead of operating on a detached value.
+		e.isDeleted = true
+		m.syncMap.Delete(key)
+	}
+}
+
+// Keys iterates through all the keys in the map. It does not expose the actual value as all
+// operations on the values should be executed through Compute, ComputeIfPresent and DeleteIf.
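+//
+// For example, collecting every key (illustrative only):
+//
+//	var keys []K
+//	m.Keys(func(key K) bool {
+//		keys = append(keys, key)
+//		return true // returning false would stop the iteration early
+//	})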
+func (m *ResourceMap[K, T]) Keys(f func(key K) bool) { + m.syncMap.Range(func(key, value any) bool { + return f(key.(K)) + }) +} diff --git a/internal/cache/subscriber_set.go b/internal/cache/subscriber_set.go new file mode 100644 index 0000000..417eb37 --- /dev/null +++ b/internal/cache/subscriber_set.go @@ -0,0 +1,130 @@ +package internal + +import ( + "iter" + "sync" + "sync/atomic" + "time" + + "github.com/linkedin/diderot/ads" + "google.golang.org/protobuf/proto" +) + +// SubscriberSetVersion is a monotonically increasing counter that tracks how many times subscribers +// have been added to a given SubscriberSet. This means a subscriber can check whether they are in a +// SubscriberSet by storing the version returned by SubscriberSet.Subscribe and comparing it against +// the version returned by SubscriberSet.Iterator. +type SubscriberSetVersion uint64 + +// SubscriberSet is a concurrency-safe data structure that stores a set of unique subscribers. It is +// specifically designed to support wildcard and glob subscriptions such that they can be shared by +// multiple watchableValues instead of requiring each WatchableValue to store each subscriber. After +// subscribing to a given value, the SubscriptionHandler is supposed to be notified of the current +// value immediately, which usually simply means reading WatchableValue.currentValue and notifying +// the handler. However, it is possible that the notification loop for the WatchableValue is already +// running, and it could result in a double notification. To avoid this, this data structure +// introduces a notion of versioning. This way, the notification loop can record which version it is +// about to iterate over (in WatchableValue.lastSeenSubscriberSetVersions) such that subscribers can +// determine whether the loop will notify them and avoid the double notification. This is done by +// recording the version returned by SubscriberSet.Subscribe and checking whether it's equal to or +// smaller than the version in WatchableValue.lastSeenSubscriberSetVersions. +// +// The implementation uses a sync.Map to store and iterate over the subscribers. In this case it's +// impossible to use a normal map since the subscriber set will be iterated over frequently. However, +// sync.Map provides no guarantees about what happens if the map is modified while another goroutine +// is iterating over the entries. Specifically, if an entry is added during the iteration, the +// iterator may or may not actually yield the new entry, which means the iterator may yield an entry +// that was added _after_ Iterator was invoked, violating the Iterator contract that it will only +// yield entries that were added before. To get around this, the returned iterator simply records the +// version at which it was initially created, and drops entries that have a greater version, making +// it always consistent. +type SubscriberSet[T proto.Message] struct { + // Protects entry creation in the set. + lock sync.Mutex + // Maps SubscriptionHandler instances to the subscriber instance containing the metadata. + subscribers sync.Map // Real type: map[SubscriptionHandler[T]]*subscriber + // The current subscriber set version. + version SubscriberSetVersion + // Stores the current number of subscribers. + size atomic.Int64 +} + +type subscriber struct { + subscribedAt time.Time + id SubscriberSetVersion +} + +// IsSubscribed checks whether the given handler is subscribed to this set. 
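+// Like Size, it is safe to call on a nil receiver, in which case it always returns false. This lets
+// callers treat an unallocated SubscriberSet as an empty one.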
+// IsSubscribed checks whether the given handler is subscribed to this set.
+func (m *SubscriberSet[T]) IsSubscribed(handler ads.SubscriptionHandler[T]) bool {
+	if m == nil {
+		return false
+	}
+
+	_, ok := m.subscribers.Load(handler)
+	return ok
+}
+
+// Subscribe registers the given SubscriptionHandler as a subscriber and returns the time and version
+// at which the subscription was processed. The returned version can be compared against the version
+// returned by Iterator to check whether the given handler is present in the iterator.
+func (m *SubscriberSet[T]) Subscribe(handler ads.SubscriptionHandler[T]) (subscribedAt time.Time, id SubscriberSetVersion) {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+
+	m.version++
+	s := &subscriber{
+		subscribedAt: timeProvider(),
+		id:           m.version,
+	}
+	_, loaded := m.subscribers.Swap(handler, s)
+	if !loaded {
+		m.size.Add(1)
+	}
+
+	return s.subscribedAt, s.id
+}
+
+// Unsubscribe removes the given handler from the set, and returns whether the set is now empty as a
+// result of this unsubscription.
+func (m *SubscriberSet[T]) Unsubscribe(handler ads.SubscriptionHandler[T]) (empty bool) {
+	_, loaded := m.subscribers.LoadAndDelete(handler)
+	if !loaded {
+		return m.size.Load() == 0
+	}
+
+	return m.size.Add(-1) == 0
+}
+
+// Size returns the number of subscribers in the set. For convenience, returns 0 if the receiver is
+// nil.
+func (m *SubscriberSet[T]) Size() int {
+	if m == nil {
+		return 0
+	}
+	return int(m.size.Load())
+}
+
+type SubscriberSetIterator[T proto.Message] iter.Seq2[ads.SubscriptionHandler[T], time.Time]
+
+// Iterator returns an iterator over the SubscriberSet. The version returned alongside the iterator
+// can be used by subscribers to check whether they are present in the iterator. For convenience,
+// returns an empty iterator and invalid version if the receiver is nil.
+func (m *SubscriberSet[T]) Iterator() (SubscriberSetIterator[T], SubscriberSetVersion) {
+	if m == nil {
+		return func(yield func(ads.SubscriptionHandler[T], time.Time) bool) {}, 0
+	}
+
+	m.lock.Lock()
+	version := m.version
+	m.lock.Unlock()
+
+	return func(yield func(ads.SubscriptionHandler[T], time.Time) bool) {
+		m.subscribers.Range(func(key, value any) bool {
+			s := value.(*subscriber)
+			if s.id > version {
+				return true
+			}
+			return yield(key.(ads.SubscriptionHandler[T]), s.subscribedAt)
+		})
+	}, version
+}
diff --git a/internal/cache/subscriber_set_test.go b/internal/cache/subscriber_set_test.go
new file mode 100644
index 0000000..421efec
--- /dev/null
+++ b/internal/cache/subscriber_set_test.go
@@ -0,0 +1,97 @@
+package internal
+
+import (
+	"testing"
+	"time"
+
+	"github.com/linkedin/diderot/ads"
+	"github.com/stretchr/testify/require"
+	. "google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/timestamppb" +) + +type noopHandler byte + +func (*noopHandler) Notify(string, *ads.Resource[*Timestamp], ads.SubscriptionMetadata) {} + +type iterateArgs struct { + handler ads.SubscriptionHandler[*Timestamp] + subscribedAt time.Time +} + +func checkIterate(t *testing.T, m *SubscriberSet[*Timestamp], expectedV SubscriberSetVersion, expectedArgs ...iterateArgs) { + require.Equal(t, m.Size(), len(expectedArgs)) + seq, v := m.Iterator() + require.Equal(t, expectedV, v) + + var actualArgs []iterateArgs + + for handler, subscribedAt := range seq { + actualArgs = append(actualArgs, iterateArgs{ + handler: handler, + subscribedAt: subscribedAt, + }) + } + require.ElementsMatch(t, expectedArgs, actualArgs) +} + +func TestSubscriberMap(t *testing.T) { + s := new(SubscriberSet[*Timestamp]) + checkIterate(t, s, 0) + + h1 := new(noopHandler) + sAt1, v := s.Subscribe(h1) + require.Equal(t, SubscriberSetVersion(1), v) + require.True(t, s.IsSubscribed(h1)) + + checkIterate(t, s, 1, + iterateArgs{ + handler: h1, + subscribedAt: sAt1, + }, + ) + + h2 := new(noopHandler) + sAt2, v := s.Subscribe(h2) + require.NotEqual(t, sAt1, sAt2) + require.Equal(t, SubscriberSetVersion(2), v) + require.True(t, s.IsSubscribed(h2)) + + checkIterate(t, s, 2, + iterateArgs{ + handler: h1, + subscribedAt: sAt1, + }, + iterateArgs{ + handler: h2, + subscribedAt: sAt2, + }, + ) + + sAt3, v := s.Subscribe(h1) + require.NotEqual(t, sAt1, sAt3) + require.Equal(t, SubscriberSetVersion(3), v) + require.True(t, s.IsSubscribed(h1)) + + checkIterate(t, s, 3, + iterateArgs{ + handler: h2, + subscribedAt: sAt2, + }, + iterateArgs{ + handler: h1, + subscribedAt: sAt3, + }, + ) + + s.Unsubscribe(h1) + require.False(t, s.IsSubscribed(h1)) + checkIterate(t, s, 3, + iterateArgs{ + handler: h2, + subscribedAt: sAt2, + }) + + s.Unsubscribe(h2) + require.False(t, s.IsSubscribed(h2)) + checkIterate(t, s, 3) +} diff --git a/internal/cache/subscription_type.go b/internal/cache/subscription_type.go new file mode 100644 index 0000000..67414dd --- /dev/null +++ b/internal/cache/subscription_type.go @@ -0,0 +1,23 @@ +package internal + +// subscriptionType describes the ways a client can subscribe to a resource. +type subscriptionType byte + +// The following subscriptionType constants define the ways a client can subscribe to a resource. See +// RawCache.Subscribe for additional details. +const ( + // An ExplicitSubscription means the client subscribed to a resource by explicit providing its name. + ExplicitSubscription = subscriptionType(iota) + // A GlobSubscription means the client subscribed to a resource by specifying its parent glob + // collection URL, implicitly subscribing it to all the resources that are part of the collection. + GlobSubscription + // A WildcardSubscription means the client subscribed to a resource by specifying the wildcard + // (ads.WildcardSubscription), implicitly subscribing it to all resources in the cache. 
+	WildcardSubscription
+
+	subscriptionTypes = iota
+)
+
+func (t subscriptionType) isImplicit() bool {
+	return t != ExplicitSubscription
+}
diff --git a/internal/cache/subscription_type_test.go b/internal/cache/subscription_type_test.go
new file mode 100644
index 0000000..4ab40d2
--- /dev/null
+++ b/internal/cache/subscription_type_test.go
@@ -0,0 +1,13 @@
+package internal
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestSubscriptionType(t *testing.T) {
+	require.False(t, ExplicitSubscription.isImplicit())
+	require.True(t, GlobSubscription.isImplicit())
+	require.True(t, WildcardSubscription.isImplicit())
+}
diff --git a/internal/cache/watchable_value.go b/internal/cache/watchable_value.go
new file mode 100644
index 0000000..b0b313e
--- /dev/null
+++ b/internal/cache/watchable_value.go
@@ -0,0 +1,412 @@
+package internal
+
+import (
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/linkedin/diderot/ads"
+	"google.golang.org/protobuf/proto"
+)
+
+type Priority int
+
+type valueWithMetadata[T proto.Message] struct {
+	resource          *ads.Resource[T]
+	modifiedAt        time.Time
+	cachedAt          time.Time
+	idx               Priority
+	globCollectionURL string
+}
+
+func (v *valueWithMetadata[T]) subscriptionMetadata(subscribedAt time.Time) ads.SubscriptionMetadata {
+	return ads.SubscriptionMetadata{
+		SubscribedAt:      subscribedAt,
+		ModifiedAt:        v.modifiedAt,
+		CachedAt:          v.cachedAt,
+		Priority:          int(v.idx),
+		GlobCollectionURL: v.globCollectionURL,
+	}
+}
+
+type WatchableValue[T proto.Message] struct {
+	name           string
+	globCollection *globCollection[T]
+
+	// lock protects all fields of this struct except name and globCollection since they are not modified
+	// after initialization. currentValue is not strictly protected by lock since it can be read without
+	// holding lock, but in practice lock is always held when updating currentValue.
+	lock sync.Mutex
+	// valuesFromDifferentPrioritySources holds the backing prioritized values. A lower index means a
+	// higher priority.
+	valuesFromDifferentPrioritySources []*ads.Resource[T]
+	// currentIndex is always a valid index into valuesFromDifferentPrioritySources and represents the
+	// currently highest priority (lowest index) non-nil value in valuesFromDifferentPrioritySources. If
+	// valuesFromDifferentPrioritySources[currentIndex] is nil, the value has been cleared at all
+	// priorities. See NewPrioritizedCache for additional details on how priority works.
+	currentIndex Priority
+	// modifiedAt contains the timestamp provided by the most recent call to set or clear.
+	modifiedAt time.Time
+	// cachedAt contains the timestamp of the most recent call to set or clear.
+	cachedAt time.Time
+	// loopStatus stores the current state of the loop, see startNotificationLoop for additional
+	// details.
+	loopStatus loopStatus
+	// loopWg is incremented every time the loop starts, and decremented whenever it completes.
+	loopWg sync.WaitGroup
+	// SubscriberSets holds all the SubscriberSet instances relevant to this WatchableValue.
+	SubscriberSets [subscriptionTypes]*SubscriberSet[T]
+	// lastSeenSubscriberSetVersions stores the SubscriberSetVersion of the most-recently iterated SubscriberSet.
+	// When subscribing to this WatchableValue, subscriber goroutines should check this value to determine
+	// whether the notification loop has already picked up their subscription and will therefore notify
+	// them itself (see NotifyHandlerAfterSubscription).
+	lastSeenSubscriberSetVersions [subscriptionTypes]SubscriberSetVersion
+
+	// currentValue always contains the value of valuesFromDifferentPrioritySources[currentIndex] so that
+	// the current value can be read without holding lock.
+	currentValue atomic.Pointer[ads.Resource[T]]
+}
+
+func NewValue[T proto.Message](
+	name string,
+	prioritySlots int,
+) *WatchableValue[T] {
+	return &WatchableValue[T]{
+		name:                               name,
+		valuesFromDifferentPrioritySources: make([]*ads.Resource[T], prioritySlots),
+		currentIndex:                       Priority(prioritySlots - 1),
+		SubscriberSets:                     [subscriptionTypes]*SubscriberSet[T]{new(SubscriberSet[T])},
+	}
+}
+
+func (v *WatchableValue[T]) IsSubscribed(handler ads.SubscriptionHandler[T]) bool {
+	return v.SubscriberSets[ExplicitSubscription].IsSubscribed(handler)
+}
+
+func (v *WatchableValue[T]) Subscribe(handler ads.SubscriptionHandler[T]) {
+	subscribedAt, version := v.SubscriberSets[ExplicitSubscription].Subscribe(handler)
+	v.NotifyHandlerAfterSubscription(handler, ExplicitSubscription, subscribedAt, version)
+}
+
+func (v *WatchableValue[T]) Unsubscribe(handler ads.SubscriptionHandler[T]) (empty bool) {
+	return v.SubscriberSets[ExplicitSubscription].Unsubscribe(handler)
+}
+
+func (v *WatchableValue[T]) Read() *ads.Resource[T] {
+	return v.currentValue.Load()
+}
+
+func (v *WatchableValue[T]) readWithMetadataNoLock() valueWithMetadata[T] {
+	var gcURL string
+	if v.globCollection != nil {
+		gcURL = v.globCollection.url
+	}
+	return valueWithMetadata[T]{
+		resource:          v.valuesFromDifferentPrioritySources[v.currentIndex],
+		modifiedAt:        v.modifiedAt,
+		cachedAt:          v.cachedAt,
+		idx:               v.currentIndex,
+		globCollectionURL: gcURL,
+	}
+}
+
+func (v *WatchableValue[T]) Clear(p Priority, clearedAt time.Time) (isFullClear bool) {
+	v.lock.Lock()
+	defer v.lock.Unlock()
+
+	defer func() {
+		isFullClear = v.valuesFromDifferentPrioritySources[v.currentIndex] == nil
+	}()
+
+	// Nothing to be done since the value is already nil
+	if v.valuesFromDifferentPrioritySources[p] == nil {
+		return
+	}
+
+	v.valuesFromDifferentPrioritySources[p] = nil
+
+	// If any value other than the current value is being cleared, no need to notify the subscribers. The
+	// invariants maintained by clear and set guarantee that v.currentIndex points to the highest
+	// priority value. In other words, there exists no index i such that i < v.currentIndex where
+	// v.valuesFromDifferentPrioritySources[i] isn't nil.
+	if p != v.currentIndex {
+		return
+	}
+
+	// If the current value is being cleared, then subscribers need to be notified of the next highest priority non-nil
+	// value. If no non-nil values remain, the subscribers need to be notified of the deletion.
+	for ; int(v.currentIndex) != len(v.valuesFromDifferentPrioritySources); v.currentIndex++ {
+		if v.valuesFromDifferentPrioritySources[v.currentIndex] != nil {
+			break
+		}
+	}
+
+	if int(v.currentIndex) == len(v.valuesFromDifferentPrioritySources) {
+		// Didn't find any non-nil values, leave v.currentIndex at the lowest priority. The invariant maintained
+		// here is that v.currentIndex should always be a valid index into v.valuesFromDifferentPrioritySources.
+		v.currentIndex = Priority(len(v.valuesFromDifferentPrioritySources) - 1)
+	}
+
+	v.notify(clearedAt)
+	return
+}
+
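A hypothetical walkthrough of the priority semantics implemented by Clear above and Set below (high, mid and low are placeholder *ads.Resource[*anypb.Any] values and t0..t5 placeholder timestamps; lower slot index wins):

```go
v := NewValue[*anypb.Any]("res", 3)
v.Set(2, low, t0)      // current value: low (the only non-nil slot)
v.Set(0, high, t1)     // current value: high (index 0 beats index 2)
v.Set(1, mid, t2)      // stored but not notified: 1 > currentIndex (0)
v.Clear(0, t3)         // falls back to index 1: current value becomes mid
v.Clear(1, t4)         // falls back to index 2: current value becomes low
full := v.Clear(2, t5) // full == true: no non-nil values remain
```

+// Set stores the given resource at the given priority and notifies all subscribers of the new value. It
+// is invalid to invoke this method with a nil resource.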
+func (v *WatchableValue[T]) Set(p Priority, r *ads.Resource[T], modifiedAt time.Time) { + v.lock.Lock() + defer v.lock.Unlock() + + v.valuesFromDifferentPrioritySources[p] = r + // Ignore updates of a lower priority than the current value + if p > v.currentIndex { + return + } + + v.currentIndex = p + + v.notify(modifiedAt) +} + +// SetTimeProvider can be used to provide an alternative time provider, which is important when +// benchmarking the cache. +func SetTimeProvider(now func() time.Time) { + timeProvider = now +} + +var timeProvider = time.Now + +func (v *WatchableValue[T]) notify(modifiedAt time.Time) { + v.cachedAt = timeProvider() + v.modifiedAt = modifiedAt + v.currentValue.Store(v.valuesFromDifferentPrioritySources[v.currentIndex]) + + v.startNotificationLoop() +} + +type loopStatus byte + +const ( + // notRunning means the loop has either never run, or has completed running through one cycle after + // the value has received an update. Subscribers should check the corresponding + // WatchableValue.lastSeenSubscriberSetVersions to see if they have already been notified. + notRunning = loopStatus(iota) + // initialized means the goroutine has been started but has not yet loaded the subscriber map and updated + // WatchableValue.lastSeenSubscriberSetVersions. + initialized + // running means the goroutine has started and has loaded the subscriber map. + running +) + +// NotifyHandlerAfterSubscription should be invoked by subscribers after subscribing to the +// corresponding SubscriberSet. This function is guaranteed to only return once the handler has been +// notified of the current value, since the xDS protocol spec explicitly states that an explicit +// subscription to an entry must always be respected by sending the current value: +// https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol#subscribing-to-resources +// +// A resource_names_subscribe field may contain resource names that the server believes the client is +// already subscribed to, and furthermore has the most recent versions of. However, the server must +// still provide those resources in the response; due to implementation details hidden from the +// server, the client may have "forgotten" those resources despite apparently remaining subscribed. +func (v *WatchableValue[T]) NotifyHandlerAfterSubscription( + handler ads.SubscriptionHandler[T], + subType subscriptionType, + subscribedAt time.Time, + version SubscriberSetVersion, +) { + v.lock.Lock() + value := v.readWithMetadataNoLock() + + switch { + // If a WatchableValue is initially nil (i.e. it's being preserved in the map to keep track of + // explicit subscriptions), avoid notifying implicit subscribers, as the resource doesn't actually + // exist. + case subType.isImplicit() && value.resource == nil: + v.lock.Unlock() + // If the loop isn't currently running and the subscriber map version is equal or greater than the + // subscriber version, the loop has already completed and has already notified the subscriber, so + // exit immediately. + case v.loopStatus != running && v.lastSeenSubscriberSetVersions[subType] >= version: + v.lock.Unlock() + // Since lastSeenSubscriberSetVersions is updated by the notification loop goroutine while holding the lock, + // it can be used to track whether the loop goroutine has already picked up the new + // SubscriptionHandler and will notify it as part of its ongoing execution. 
In this case, the
+	// subscriber goroutine simply waits for the loop to complete to guarantee that the handler has been
+	// notified. This is as opposed to notifying the handler directly even though the loop is running,
+	// potentially resulting in a double notification.
+	case v.loopStatus == initialized || (v.loopStatus == running && v.lastSeenSubscriberSetVersions[subType] >= version):
+		v.lock.Unlock()
+		v.loopWg.Wait()
+	default:
+		handler.Notify(v.name, value.resource, value.subscriptionMetadata(subscribedAt))
+		v.lock.Unlock()
+	}
+}
+
+// startNotificationLoop spawns a goroutine that will notify all the subscribers to this entry of the
+// current value. If the current loopStatus is not notRunning (i.e. the goroutine from a previous
+// invocation is still running), immediately returns and does nothing. Must be invoked while holding
+// lock. If the value is updated while the subscribers are being notified, it will bail on updating
+// the rest of the subscribers and start from the top again. This way the routine can be reused by
+// back-to-back updates instead of creating a new one every time. The loopStatus will be updated to
+// reflect the current status of the goroutine, i.e. it will be initialized while the goroutine is
+// being spun up, then running once the goroutine has loaded the SubscriberSets.
+func (v *WatchableValue[T]) startNotificationLoop() {
+	if v.loopStatus != notRunning {
+		return
+	}
+
+	v.loopStatus = initialized
+	v.loopWg.Add(1)
+
+	go func() {
+		defer v.loopWg.Done()
+		for {
+			v.lock.Lock()
+			value := v.readWithMetadataNoLock()
+
+			var subscriberIterators [subscriptionTypes]SubscriberSetIterator[T]
+			for i := range subscriptionTypes {
+				subscriberIterators[i], v.lastSeenSubscriberSetVersions[i] = v.SubscriberSets[i].Iterator()
+			}
+
+			v.loopStatus = running
+			v.lock.Unlock()
+
+			if v.globCollection != nil && value.resource != nil {
+				// To ensure proper ordering of notifications for subscribers, it's important to notify the
+				// collection of created resources _before_ looping through the subscribers, and to notify it of
+				// deleted resources _after_.
+				v.globCollection.resourceSet(v.name)
+			}
+
+			if !v.notifySubscribers(value, subscriberIterators) {
+				// If notifySubscribers returns false, the value changed during the loop. Immediately reload the
+				// value and try again, reusing the goroutine.
+				continue
+			}
+
+			if v.globCollection != nil && value.resource == nil {
+				v.globCollection.resourceCleared(v.name)
+			}
+
+			v.lock.Lock()
+			done := v.valuesFromDifferentPrioritySources[v.currentIndex] == value.resource
+			if done {
+				// At this point, the most recent value was successfully pushed to all subscribers since it has not
+				// changed from when it was initially read at the top of the loop. Since the lock is currently held,
+				// setting loopStatus back to notRunning will signal to the next invocation of
+				// startNotificationLoop that the loop goroutine is not running.
+				v.loopStatus = notRunning
+			}
+			// Otherwise, if done isn't true then the value changed in between notifying the subscribers and
+			// grabbing the lock. In this case the loop will restart, reusing the goroutine.
+			v.lock.Unlock()
+
+			if done {
+				return
+			}
+		}
+	}()
+}
+
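Stripped of the subscriber-set versioning and glob-collection bookkeeping, the goroutine-reuse pattern above reduces to the following standalone toy (an illustration, not code from this patch):

```go
package main

import "sync"

// coalescer is a toy version of the loop above: set never blocks, a single
// goroutine delivers values, and intermediate values may be skipped.
type coalescer struct {
	mu      sync.Mutex
	current int
	running bool
	deliver func(int) // slow consumer, e.g. a network write
}

func (c *coalescer) set(v int) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.current = v
	if c.running {
		return // the running goroutine will re-read current and pick this up
	}
	c.running = true
	go c.loop()
}

func (c *coalescer) loop() {
	for {
		c.mu.Lock()
		v := c.current
		c.mu.Unlock()

		c.deliver(v) // values set while this runs are coalesced

		c.mu.Lock()
		if c.current == v {
			c.running = false // nothing new arrived: exit the goroutine
			c.mu.Unlock()
			return
		}
		c.mu.Unlock() // the value changed mid-delivery: loop again
	}
}

func main() {
	c := &coalescer{deliver: func(int) {}}
	for i := 0; i < 100; i++ {
		c.set(i) // many of these will be coalesced into few deliveries
	}
}
```

+// notifySubscribers is invoked by WatchableValue.startNotificationLoop with the desired update. It
+// returns true if all the subscribers were notified, or false if the loop exited early because the
+// value changed during the iteration.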
+func (v *WatchableValue[T]) notifySubscribers( + value valueWithMetadata[T], + iterators [subscriptionTypes]SubscriberSetIterator[T], +) bool { + for handler, subscribedAt := range v.iterateSubscribers(iterators) { + // If the value has changed while looping over the subscribers, stop the loop and try again with the + // latest value. This avoids doing duplicate/wasted work since subscribers only care about the latest + // version of the value and don't mind missing intermediate values. + if v.Read() != value.resource { + return false + } + + handler.Notify(v.name, value.resource, value.subscriptionMetadata(subscribedAt)) + } + return true +} + +// iterateSubscribers returns a SubscriberSetIterator that iterates over all the unique subscribers +// to this value. This means if a subscriber is subscribed to this value both with a wildcard and an +// explicit subscription, the returned sequence will only yield the handler once. +// +// Author's note: there is a known race condition in this function if a subscriber is present in more +// than one iterator. The loop iterates through the given iterators in order, and looks back at +// previous iterators to check whether a given SubscriptionHandler has already been notified with +// SubscriberSet.IsSubscribed. Suppose the following sequence: +// 1. A SubscriptionHandler "foo" is present both as an explicit and wildcard subscriber to this +// value. +// 2. iterateSubscribers completes the iteration over the explicit subscribers, yielding foo once. +// 3. foo unsubscribes from the explicit SubscriberSet as iterateSubscribers begins iterating over the +// wildcard SubscriberSet. iterateSubscribers encounters foo in the wildcard SubscriberSet and checks +// whether foo is present in the explicit SubscriberSet (which it has already iterated over), sees +// that foo is not explicitly subscribed, and yields foo again since it believes it is the first time +// foo is encountered. +// +// Such a scenario can lead to foo being notified of the same update twice, which is undesirable but +// acceptable. Note however that the sequence of subscriptions "subscribe explicit" -> "subscribe +// wildcard" -> "unsubscribe explicit" does not actually exist in the standard Envoy flow, meaning +// this edge case is extremely unlikely. This is because iterateSubscribers always iterates through +// the subscriber sets in the same order: ExplicitSubscription -> GlobSubscription -> +// WildcardSubscription. The [known flow] of "subscribe wildcard" -> "subscribe explicit" -> +// "unsubscribe wildcard" is therefore completely unaffected by this and behaves as expected where +// foo is only yielded once. +// +// Possible alternatives to this iteration approach were considered: +// +// # Reverse checking order +// +// Instead of checking whether a SubscriptionHandler was already yielded by looking at *previous* +// SubscriberSets, iterateSubscribers could look at the *upcoming* iterators and determine whether a +// handler will eventually be yielded by a subsequent iterator, and skip yielding the handler from +// this iterator. While this addresses the edge case above, it symmetrically opens a significantly +// worse edge case in the reverse scenario: +// 1. A SubscriptionHandler "foo" is present both as an explicit and wildcard subscriber to this +// value. +// 2. iterateSubscribers completes the iteration over the explicit subscribers, skipping foo since it +// sees that it is currently present in the wildcard SubscriberSet. +// 3. 
foo unsubscribes from the wildcard SubscriberSet as iterateSubscribers begins iterating over the
+// wildcard SubscriberSet. foo is never encountered again, meaning it will never be notified of the
+// resource update.
+//
+// This edge case means updates will altogether be dropped rather than delivered twice, which is
+// unacceptable. Not only that, unlike the current iteration approach, the [known flow] is likely to
+// trigger this edge case, making this approach completely non-viable.
+//
+// # Running set
+//
+// The obvious way to iterate through the union of subscribers would be to simply use a set to keep
+// track of all the SubscriptionHandlers that have already been yielded. This however defeats the
+// very point of the SubscriberSet data structure in the first place, which *avoids* storing state
+// for each subscriber. It would create memory pressure as each invocation of iterateSubscribers
+// would create a new map which would get promptly discarded once the invocation exits. The memory
+// pressure could be addressed by leveraging a sync.Pool to allow reusing these sets, but it would
+// only introduce complexity and overhead in an already complex loop. As such, this was ultimately
+// deemed not necessary as the current solution will only duplicate notifications in
+// unknown/non-standard subscription flows.
+//
+// [known flow]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol#how-the-client-specifies-what-resources-to-return
+func (v *WatchableValue[T]) iterateSubscribers(
+	iterators [subscriptionTypes]SubscriberSetIterator[T],
+) SubscriberSetIterator[T] {
+	return func(yield func(ads.SubscriptionHandler[T], time.Time) bool) {
+		for i, subscribers := range iterators {
+		subscriberLoop:
+			for handler, subscribedAt := range subscribers {
+				// If this handler was already yielded once from a previous iterator, skip it.
+				for j := i - 1; j >= 0; j-- {
+					if v.SubscriberSets[j].IsSubscribed(handler) {
+						continue subscriberLoop
+					}
+				}
+
+				if !yield(handler, subscribedAt) {
+					// Stop the iteration entirely rather than just this SubscriberSet: yield must never
+					// be invoked again once it has returned false.
+					return
+				}
+			}
+		}
+	}
+}
diff --git a/internal/server/handlers.go b/internal/server/handlers.go
new file mode 100644
index 0000000..9ff836f
--- /dev/null
+++ b/internal/server/handlers.go
@@ -0,0 +1,364 @@
+package internal
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/linkedin/diderot/ads"
+	"github.com/linkedin/diderot/internal/utils"
+	serverstats "github.com/linkedin/diderot/stats/server"
+	"golang.org/x/time/rate"
+	"google.golang.org/protobuf/proto"
+)
+
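Before the interface definition below, a sketch of the intended call sequence may help: a subscription manager (defined later in this patch) brackets a request's subscriptions with Start/EndNotificationBatch so that the many synchronous Notify calls triggered by subscribing collapse into a single response. This is a hypothetical caller; h, locator, streamCtx, typeURL and req are assumed to be in scope, and real code would retain the unsubscribe funcs:

```go
h.StartNotificationBatch()
for _, name := range req.ResourceNamesSubscribe {
	// Subscribing synchronously notifies h of the resource's current value.
	unsubscribe := locator.Subscribe(streamCtx, typeURL, name, h)
	_ = unsubscribe
}
h.EndNotificationBatch() // flushes everything accumulated above as one batch
```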
+// BatchSubscriptionHandler is an extension of the SubscriptionHandler interface in the root package
+// which allows a handler to be notified that a batch of calls to Notify is about to be received
+// (StartNotificationBatch). The batch of notifications should not be sent to the client until all
+// notifications for that batch have been received (EndNotificationBatch). Start and End will never
+// be invoked out of order, i.e. there will never be a call to EndNotificationBatch without a call to
+// StartNotificationBatch immediately preceding it. However, SubscriptionHandler.Notify can be
+// invoked at any point.
+type BatchSubscriptionHandler interface {
+	StartNotificationBatch()
+	ads.RawSubscriptionHandler
+	EndNotificationBatch()
+}
+
+type entry struct {
+	Resource *ads.RawResource
+	metadata ads.SubscriptionMetadata
+}
+
+func newHandler(
+	ctx context.Context,
+	granularLimiter handlerLimiter,
+	globalLimiter handlerLimiter,
+	statsHandler serverstats.Handler,
+	ignoreDeletes bool,
+	send func(entries map[string]entry) error,
+) *handler {
+	h := &handler{
+		granularLimiter:               granularLimiter,
+		globalLimiter:                 globalLimiter,
+		statsHandler:                  statsHandler,
+		ctx:                           ctx,
+		ignoreDeletes:                 ignoreDeletes,
+		send:                          send,
+		entries:                       map[string]entry{},
+		immediateNotificationReceived: newNotifyOnceChan(),
+		notificationReceived:          newNotifyOnceChan(),
+	}
+	go h.loop()
+	return h
+}
+
+func newNotifyOnceChan() notifyOnceChan {
+	return make(chan struct{}, 1)
+}
+
+// notifyOnceChan is a resettable chan that only receives a notification once. It is exclusively
+// meant to be used by handler. All methods should be invoked while holding the corresponding
+// handler.lock.
+type notifyOnceChan chan struct{}
+
+// notify notifies the channel using a non-blocking send
+func (ch notifyOnceChan) notify() {
+	select {
+	case ch <- struct{}{}:
+	default:
+	}
+}
+
+// reset ensures the channel has no pending notifications in case they were never read (this can
+// happen if a notification comes in after the granular rate limit clears but before the
+// corresponding handler.lock is acquired).
+func (ch notifyOnceChan) reset() {
+	select {
+	// clear the channel if it has a pending notification. This is required since
+	// immediateNotificationReceived can be notified _after_ the granular limit clears. If it isn't
+	// cleared during the reset, the loop will read from it and incorrectly detect an immediate
+	// notification.
+	case <-ch:
+	// otherwise return immediately if the channel is empty
+	default:
+	}
+}
+
+// handler implements the BatchSubscriptionHandler interface using a backing map to aggregate updates
+// as they come in, flushing them out whenever the limiter permits it.
+type handler struct {
+	granularLimiter handlerLimiter
+	globalLimiter   handlerLimiter
+	statsHandler    serverstats.Handler
+	lock            sync.Mutex
+	ctx             context.Context
+	ignoreDeletes   bool
+	send            func(entries map[string]entry) error
+
+	entries map[string]entry
+
+	// The following notifyOnceChan instances are the signaling mechanism between loop and Notify. Calls
+	// to Notify will first invoke notifyOnceChan.notify on immediateNotificationReceived based on the
+	// contents of the subscription metadata, then call notify on notificationReceived. loop waits on the
+	// channel that backs notificationReceived to be signaled and once the first notification is
+	// received, waits for the global rate limit to clear. This allows updates to keep accumulating. It
+	// then checks whether immediateNotificationReceived has been signaled, and if so skips the granular
+	// rate limiter. Otherwise, it either waits for the granular rate limit to clear, or
+	// immediateNotificationReceived to be signaled, whichever comes first. Only then does it invoke
+	// swapResourceMaps which resets notificationReceived, immediateNotificationReceived and entries to a
+	// state where they can receive more notifications while, in the background, it invokes send with all
+	// accumulated entries up to this point. Once send completes, it returns to waiting on
+	// notificationReceived. All operations involving these channels will exit early if ctx is cancelled,
+	// terminating the loop.
+	immediateNotificationReceived notifyOnceChan
+	notificationReceived          notifyOnceChan
+
+	// If batchStarted is true, Notify will not notify notificationReceived. This allows the batch to
+	// complete before the response is sent, minimizing the number of responses.
+	batchStarted bool
+}
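The struct above leans on notifyOnceChan's coalescing behavior, which in isolation works like this (illustrative sketch, not part of the patch):

```go
ch := newNotifyOnceChan()
ch.notify() // first notification is buffered in the capacity-1 channel
ch.notify() // coalesced: a signal is already pending, so this is a no-op
<-ch        // consumes the single pending notification
ch.reset()  // no-op here; would drain a pending signal that raced in
```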
+
+// swapResourceMaps replaces h.entries with the given (empty) map, returns the previous contents of
+// h.entries, and resets immediateNotificationReceived and notificationReceived.
+func (h *handler) swapResourceMaps(entries map[string]entry) map[string]entry {
+	h.lock.Lock()
+	defer h.lock.Unlock()
+	entries, h.entries = h.entries, entries
+	h.notificationReceived.reset()
+	h.immediateNotificationReceived.reset()
+	return entries
+}
+
+func (h *handler) loop() {
+	entries := map[string]entry{}
+
+	for {
+		select {
+		case <-h.ctx.Done():
+			return
+		case <-h.notificationReceived:
+			// Always wait for the global rate limiter to clear
+			if waitForGlobalLimiter(h.ctx, h.globalLimiter, h.statsHandler) != nil {
+				return
+			}
+			// Wait for the granular rate limiter
+			if h.waitForGranularLimiterOrShortCircuit() != nil {
+				return
+			}
+		}
+
+		entries = h.swapResourceMaps(entries)
+
+		if err := h.send(entries); err != nil {
+			return
+		}
+
+		// TODO: have an admin UI that shows which clients are lagging the most
+		clear(entries)
+	}
+}
+
+func waitForGlobalLimiter(
+	ctx context.Context,
+	globalLimiter handlerLimiter,
+	statsHandler serverstats.Handler,
+) error {
+	if statsHandler != nil {
+		start := time.Now()
+		defer func() {
+			statsHandler.HandleServerEvent(ctx, &serverstats.TimeInGlobalRateLimiter{Duration: time.Since(start)})
+		}()
+	}
+
+	reservation, cancel := globalLimiter.reserve()
+	defer cancel()
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-reservation:
+		return nil
+	}
+}
+
+// waitForGranularLimiterOrShortCircuit will acquire a reservation from granularLimiter and wait on
+// it, but will short circuit the reservation if an immediate notification is received (or if ctx is
+// canceled).
+func (h *handler) waitForGranularLimiterOrShortCircuit() error {
+	reservation, cancel := h.granularLimiter.reserve()
+	defer cancel()
+
+	for {
+		select {
+		case <-h.ctx.Done():
+			return h.ctx.Err()
+		case <-h.immediateNotificationReceived:
+			// If an immediate notification is received, immediately return instead of waiting for the granular
+			// limit. Without this, a bootstrapping client may be forced to wait for the initial versions of the
+			// resources it is interested in. The purpose of the rate limiter is to avoid overwhelming the
+			// client, however if the client is healthy enough to request new resources then those resources
+			// should be sent without delay. Do note, however, that the responses will still always be rate
+			// limited by the global limiter.
+			return nil
+		case <-reservation:
+			// Otherwise, wait for the granular rate limit to clear.
+ return nil + } + } +} + +func (h *handler) Notify(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata) { + h.lock.Lock() + defer h.lock.Unlock() + + if h.statsHandler != nil { + h.statsHandler.HandleServerEvent(h.ctx, &serverstats.ResourceQueued{ + ResourceName: name, + Resource: r, + Metadata: metadata, + ResourceExists: !metadata.CachedAt.IsZero(), + }) + } + + if r == nil && h.ignoreDeletes { + return + } + + h.entries[name] = entry{ + Resource: r, + metadata: metadata, + } + + if r != nil && metadata.GlobCollectionURL != "" { + // When a glob collection is empty, it is signaled to the client with a corresponding deletion of + // that collection's name. For example, if a collection Foo/* becomes empty (or the client subscribed + // to a collection that does not exist), it will receive a deletion notification for Foo/*. There is + // an edge case in the following scenario: suppose a collection currently has some resource Foo/A in + // it. Upon subscribing, the handler will be notified that the resource exists. Foo/A is then + // removed, so the handler receives a notification that Foo/A is removed, and because Foo/* is empty + // it also receives a corresponding notification. But, soon after, resource Foo/B is created, + // reviving Foo/* and the handler receives the corresponding notification for Foo/B. At this point, + // if the response were to be sent as-is, it would contain both the creation of Foo/B and the + // deletion of Foo/*. Depending on the order in which the client processes the response's contents, + // it may ignore Foo/B altogether. To avoid this, always clear out the deletion of Foo/* when a + // notification for the creation of an entry within Foo/* is received. + delete(h.entries, metadata.GlobCollectionURL) + } + + if !h.batchStarted { + h.notificationReceived.notify() + } +} + +func (h *handler) ResourceMarshalError(name string, resource proto.Message, err error) { + if h.statsHandler != nil { + h.statsHandler.HandleServerEvent(h.ctx, &serverstats.ResourceMarshalError{ + ResourceName: name, + Resource: resource, + Err: err, + }) + } +} + +func (h *handler) StartNotificationBatch() { + h.lock.Lock() + defer h.lock.Unlock() + + h.batchStarted = true +} + +func (h *handler) EndNotificationBatch() { + h.lock.Lock() + defer h.lock.Unlock() + + h.batchStarted = false + + if len(h.entries) > 0 { + h.immediateNotificationReceived.notify() + h.notificationReceived.notify() + } +} + +func NewSotWHandler( + ctx context.Context, + granularLimiter *rate.Limiter, + globalLimiter *rate.Limiter, + statsHandler serverstats.Handler, + typeURL string, + send func(res *ads.SotWDiscoveryResponse) error, +) BatchSubscriptionHandler { + return newSotWHandler( + ctx, + (*rateLimiterWrapper)(granularLimiter), + (*rateLimiterWrapper)(globalLimiter), + statsHandler, + typeURL, + send, + ) +} + +func newSotWHandler( + ctx context.Context, + granularLimiter handlerLimiter, + globalLimiter handlerLimiter, + statsHandler serverstats.Handler, + typeUrl string, + send func(res *ads.SotWDiscoveryResponse) error, +) *handler { + isPseudoDeltaSotW := utils.IsPseudoDeltaSotW(typeUrl) + var looper func(resources map[string]entry) error + if isPseudoDeltaSotW { + looper = func(entries map[string]entry) error { + versions := map[string]string{} + + for name, e := range entries { + versions[name] = e.Resource.Version + } + + res := &ads.SotWDiscoveryResponse{ + TypeUrl: typeUrl, + Nonce: utils.NewNonce(), + } + for _, e := range entries { + res.Resources = append(res.Resources, 
e.Resource.Resource) + } + res.VersionInfo = utils.MapToProto(versions) + return send(res) + } + } else { + allResources := map[string]entry{} + versions := map[string]string{} + + looper = func(resources map[string]entry) error { + for name, r := range resources { + if r.Resource != nil { + allResources[name] = r + versions[name] = r.Resource.Version + } else { + delete(allResources, name) + delete(versions, name) + } + } + + res := &ads.SotWDiscoveryResponse{ + TypeUrl: typeUrl, + Nonce: utils.NewNonce(), + } + for _, r := range allResources { + res.Resources = append(res.Resources, r.Resource.Resource) + } + res.VersionInfo = utils.MapToProto(versions) + return send(res) + } + } + + return newHandler( + ctx, + granularLimiter, + globalLimiter, + statsHandler, + isPseudoDeltaSotW, + looper, + ) +} diff --git a/internal/server/handlers_bench_test.go b/internal/server/handlers_bench_test.go new file mode 100644 index 0000000..95c70eb --- /dev/null +++ b/internal/server/handlers_bench_test.go @@ -0,0 +1,69 @@ +package internal + +import ( + "fmt" + "strconv" + "sync" + "testing" + + "github.com/linkedin/diderot/ads" + "github.com/linkedin/diderot/testutils" + "google.golang.org/protobuf/types/known/anypb" +) + +func benchmarkHandlers(tb testing.TB, count, subscriptions int) { + valueNames := make([]string, subscriptions) + for i := range valueNames { + valueNames[i] = strconv.Itoa(i) + } + + ctx := testutils.Context(tb) + + var finished sync.WaitGroup + finished.Add(subscriptions) + const finalVersion = "done" + h := newHandler( + ctx, + NoopLimiter{}, + NoopLimiter{}, + new(customStatsHandler), + false, + func(resources map[string]entry) error { + for _, r := range resources { + if r.Resource.Version == finalVersion { + finished.Done() + } + } + return nil + }, + ) + + for _, name := range valueNames { + go func(name string) { + resource := testutils.MustMarshal(tb, ads.NewResource(name, "0", new(anypb.Any))) + for i := 0; i < count-1; i++ { + h.Notify(name, resource, ads.SubscriptionMetadata{}) + } + h.Notify( + name, + &ads.RawResource{Name: name, Version: finalVersion, Resource: resource.Resource}, + ads.SubscriptionMetadata{}, + ) + }(name) + } + finished.Wait() +} + +var increments = []int{1, 10, 100, 1000, 10_000} + +func BenchmarkHandlers(b *testing.B) { + for _, subscriptions := range increments { + b.Run(fmt.Sprintf("%5d subs", subscriptions), func(b *testing.B) { + benchmarkHandlers(b, b.N, subscriptions) + }) + } +} + +func TestHandlers(t *testing.T) { + benchmarkHandlers(t, 1000, 1000) +} diff --git a/internal/server/handlers_delta.go b/internal/server/handlers_delta.go new file mode 100644 index 0000000..21d708e --- /dev/null +++ b/internal/server/handlers_delta.go @@ -0,0 +1,204 @@ +package internal + +import ( + "cmp" + "context" + "log/slog" + "slices" + + "github.com/linkedin/diderot/ads" + "github.com/linkedin/diderot/internal/utils" + serverstats "github.com/linkedin/diderot/stats/server" + "golang.org/x/time/rate" + "google.golang.org/protobuf/proto" +) + +func NewDeltaHandler( + ctx context.Context, + granularLimiter *rate.Limiter, + globalLimiter *rate.Limiter, + statsHandler serverstats.Handler, + maxChunkSize int, + typeUrl string, + send func(res *ads.DeltaDiscoveryResponse) error, +) BatchSubscriptionHandler { + return newDeltaHandler( + ctx, + (*rateLimiterWrapper)(granularLimiter), + (*rateLimiterWrapper)(globalLimiter), + statsHandler, + maxChunkSize, + typeUrl, + send, + ) +} + +func newDeltaHandler( + ctx context.Context, + granularLimiter handlerLimiter, + 
globalLimiter handlerLimiter,
+	statsHandler serverstats.Handler,
+	maxChunkSize int,
+	typeURL string,
+	send func(res *ads.DeltaDiscoveryResponse) error,
+) *handler {
+	ds := &deltaSender{
+		ctx:          ctx,
+		typeURL:      typeURL,
+		maxChunkSize: maxChunkSize,
+		statsHandler: statsHandler,
+		minChunkSize: initialChunkSize(typeURL),
+	}
+
+	return newHandler(
+		ctx,
+		granularLimiter,
+		globalLimiter,
+		statsHandler,
+		false,
+		func(entries map[string]entry) error {
+			for i, chunk := range ds.chunk(entries) {
+				if i > 0 {
+					// Respect the global limiter in between chunks
+					err := waitForGlobalLimiter(ctx, globalLimiter, statsHandler)
+					if err != nil {
+						return err
+					}
+				}
+				err := send(chunk)
+				if err != nil {
+					return err
+				}
+			}
+			return nil
+		},
+	)
+}
+
+type queuedResourceUpdate struct {
+	Name string
+	Size int
+}
+
+type deltaSender struct {
+	ctx          context.Context
+	typeURL      string
+	statsHandler serverstats.Handler
+	// The maximum size (in bytes) that a chunk can be. This is determined by the client as anything
+	// larger than this size will cause the message to be dropped.
+	maxChunkSize int
+
+	// This slice is reused by chunk. It contains the updates about to be sent, sorted by their size over
+	// the wire.
+	queuedUpdates []queuedResourceUpdate
+	// The minimum size an encoded chunk will serialize to, in bytes. Used to check whether a given
+	// update can _ever_ be sent, and as the initial size of a chunk. Note that this value only depends
+	// on utils.NonceLength and the length of typeURL.
+	minChunkSize int
+}
+
+func (ds *deltaSender) chunk(resourceUpdates map[string]entry) (chunks []*ads.DeltaDiscoveryResponse) {
+	defer func() {
+		clear(ds.queuedUpdates)
+		ds.queuedUpdates = ds.queuedUpdates[:0]
+	}()
+	for name, e := range resourceUpdates {
+		ds.queuedUpdates = append(ds.queuedUpdates, queuedResourceUpdate{
+			Name: name,
+			Size: encodedUpdateSize(name, e.Resource),
+		})
+	}
+	// Sort the updates by descending size
+	slices.SortFunc(ds.queuedUpdates, func(a, b queuedResourceUpdate) int {
+		return -cmp.Compare(a.Size, b.Size)
+	})
+
+	// This nested loop builds the fewest possible chunks it can from the given resourceUpdates map. It
+	// implements an approximation of the bin-packing algorithm called next-fit-decreasing bin-packing
+	// https://en.wikipedia.org/wiki/Next-fit-decreasing_bin_packing
+	idx := 0
+	for idx < len(ds.queuedUpdates) {
+		// This chunk will hold all the updates for this loop iteration
+		chunk := ds.newChunk()
+		chunkSize := proto.Size(chunk)
+
+		for ; idx < len(ds.queuedUpdates); idx++ {
+			update := ds.queuedUpdates[idx]
+			r := resourceUpdates[update.Name].Resource
+
+			if ds.maxChunkSize > 0 {
+				if ds.minChunkSize+update.Size > ds.maxChunkSize {
+					// This condition only occurs if the update can never be sent, i.e. it is too large and will
+					// always be dropped by the client. It should therefore be skipped altogether, but flagged
+					// accordingly.
+					if ds.statsHandler != nil {
+						ds.statsHandler.HandleServerEvent(ds.ctx, &serverstats.ResourceOverMaxSize{
+							Resource:        r,
+							ResourceSize:    update.Size,
+							MaxResourceSize: ds.maxChunkSize,
+						})
+					}
+					slog.ErrorContext(
+						ds.ctx,
+						"Cannot send resource update because it is larger than configured max delta response size",
+						"maxDeltaResponseSize", ds.maxChunkSize,
+						"name", update.Name,
+						"updateSize", update.Size,
+						"resource", r,
+					)
+					continue
+				}
+				if chunkSize+update.Size > ds.maxChunkSize {
+					// This update is too large to be sent along with the current chunk, skip it for now and
+					// attempt it in the next chunk.
+					break
+				}
+			}
+
+			if r != nil {
+				chunk.Resources = append(chunk.Resources, r)
+			} else {
+				chunk.RemovedResources = append(chunk.RemovedResources, update.Name)
+			}
+			// Account for the update now that it has been added to the chunk
+			chunkSize += update.Size
+		}
+
+		chunks = append(chunks, chunk)
+	}
+
+	if len(chunks) > 1 {
+		slog.WarnContext(
+			ds.ctx,
+			"Response exceeded max response size, sent in chunks",
+			"chunks", len(chunks),
+			"typeURL", ds.typeURL,
+			"updates", len(ds.queuedUpdates),
+		)
+	}
+
+	return chunks
+}
+
+func (ds *deltaSender) newChunk() *ads.DeltaDiscoveryResponse {
+	return &ads.DeltaDiscoveryResponse{
+		TypeUrl: ds.typeURL,
+		Nonce:   utils.NewNonce(),
+	}
+}
+
+const protobufSliceOverhead = 2
+
+func initialChunkSize(typeUrl string) int {
+	return protobufSliceOverhead + len(typeUrl) + protobufSliceOverhead + utils.NonceLength
+}
+
+// encodedUpdateSize returns the amount of bytes it takes to encode the given update in an *ads.DeltaDiscoveryResponse.
+func encodedUpdateSize(name string, r *ads.RawResource) int {
+	resourceSize := protobufSliceOverhead
+	if r != nil {
+		resourceSize += proto.Size(r)
+	} else {
+		resourceSize += len(name)
+	}
+	return resourceSize
+}
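The next-fit-decreasing strategy referenced in chunk is easiest to see on plain integers. A standalone toy version follows (illustrative only; the real implementation above must additionally skip unsendable items and account for the fixed per-chunk overhead computed by initialChunkSize):

```go
package main

import "fmt"

// pack closes a bin as soon as the next item no longer fits; closed bins are
// never revisited. sizes must already be sorted in descending order.
func pack(sizes []int, maxBin int) [][]int {
	var bins [][]int
	var bin []int
	binSize := 0
	for _, s := range sizes {
		if binSize+s > maxBin && len(bin) > 0 {
			bins = append(bins, bin)
			bin, binSize = nil, 0
		}
		bin = append(bin, s)
		binSize += s
	}
	return append(bins, bin)
}

func main() {
	// Already sorted descending, mirroring the slices.SortFunc call above.
	fmt.Println(pack([]int{9, 7, 4, 3, 1}, 10)) // [[9] [7] [4 3 1]]
}
```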
diff --git a/internal/server/handlers_delta_test.go b/internal/server/handlers_delta_test.go
new file mode 100644
index 0000000..c19e482
--- /dev/null
+++ b/internal/server/handlers_delta_test.go
@@ -0,0 +1,178 @@
+package internal
+
+import (
+	"context"
+	"strings"
+	"sync/atomic"
+	"testing"
+
+	"github.com/linkedin/diderot/ads"
+	"github.com/linkedin/diderot/internal/utils"
+	serverstats "github.com/linkedin/diderot/stats/server"
+	"github.com/linkedin/diderot/testutils"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/types/known/wrapperspb"
+)
+
+func TestDeltaHandler(t *testing.T) {
+	l := NewTestHandlerLimiter()
+
+	typeURL := utils.GetTypeURL[*wrapperspb.BoolValue]()
+	var lastRes *ads.DeltaDiscoveryResponse
+	h := newDeltaHandler(
+		testutils.Context(t),
+		NoopLimiter{},
+		l,
+		new(customStatsHandler),
+		0,
+		typeURL,
+		func(res *ads.DeltaDiscoveryResponse) error {
+			defer l.Done()
+			lastRes = res
+			return nil
+		},
+	)
+
+	const foo, bar = "foo", "bar"
+	h.Notify(foo, nil, ignoredMetadata)
+	r := new(ads.RawResource)
+	h.Notify(bar, r, ignoredMetadata)
+
+	l.Release()
+
+	require.Equal(t, typeURL, lastRes.TypeUrl)
+	require.Len(t, lastRes.Resources, 1)
+	require.Equal(t, r, lastRes.Resources[0])
+	require.Equal(t, []string{foo}, lastRes.RemovedResources)
+}
+
+func TestEncodedUpdateSize(t *testing.T) {
+	foo := testutils.MustMarshal(t, ads.NewResource("foo", "42", new(wrapperspb.Int64Value)))
+	notFoo := testutils.MustMarshal(t, ads.NewResource("notFoo", "27", new(wrapperspb.Int64Value)))
+
+	checkSize := func(t *testing.T, msg proto.Message, size int) {
+		require.Equal(t, size, proto.Size(msg))
+		data, err := proto.Marshal(msg)
+		require.NoError(t, err)
+		require.Len(t, data, size)
+	}
+
+	ds := &deltaSender{typeURL: utils.GetTypeURL[*wrapperspb.StringValue]()}
+
+	t.Run("add", func(t *testing.T) {
+		res := ds.newChunk()
+		responseSize := proto.Size(res)
+
+		res.Resources = append(res.Resources, foo)
+		responseSize += encodedUpdateSize(foo.Name, foo)
+		checkSize(t, res, responseSize)
+
+		res.Resources = append(res.Resources, notFoo)
+		responseSize += encodedUpdateSize(notFoo.Name, notFoo)
+		checkSize(t, res, responseSize)
+	})
+	t.Run("remove", func(t *testing.T) {
+		res := ds.newChunk()
+		responseSize := proto.Size(res)
+
res.RemovedResources = append(res.RemovedResources, foo.Name) + responseSize += encodedUpdateSize(foo.Name, nil) + checkSize(t, res, responseSize) + + res.RemovedResources = append(res.RemovedResources, notFoo.Name) + responseSize += encodedUpdateSize(notFoo.Name, nil) + checkSize(t, res, responseSize) + }) + t.Run("add and remove", func(t *testing.T) { + res := ds.newChunk() + responseSize := proto.Size(res) + + res.Resources = append(res.Resources, foo) + responseSize += encodedUpdateSize(foo.Name, foo) + checkSize(t, res, responseSize) + + res.RemovedResources = append(res.RemovedResources, notFoo.Name) + responseSize += encodedUpdateSize(notFoo.Name, nil) + checkSize(t, res, responseSize) + }) +} + +func TestInitialChunkSize(t *testing.T) { + typeURL := utils.GetTypeURL[*wrapperspb.StringValue]() + require.Equal(t, proto.Size(&ads.DeltaDiscoveryResponse{ + TypeUrl: typeURL, + Nonce: utils.NewNonce(), + }), initialChunkSize(typeURL)) +} + +func TestDeltaHandlerChunking(t *testing.T) { + foo := testutils.MustMarshal(t, ads.NewResource("foo", "0", wrapperspb.String("foo"))) + bar := testutils.MustMarshal(t, ads.NewResource("bar", "0", wrapperspb.String("bar"))) + require.Equal(t, proto.Size(foo), proto.Size(bar)) + resourceSize := proto.Size(foo) + + typeURL := utils.GetTypeURL[*wrapperspb.StringValue]() + statsHandler := new(customStatsHandler) + ds := &deltaSender{ + typeURL: typeURL, + statsHandler: statsHandler, + maxChunkSize: initialChunkSize(typeURL) + protobufSliceOverhead + resourceSize, + minChunkSize: initialChunkSize(typeURL), + } + + sentResponses := ds.chunk(map[string]entry{ + foo.Name: {Resource: foo}, + bar.Name: {Resource: bar}, + }) + + require.Equal(t, len(sentResponses[0].Resources), 1) + require.Equal(t, len(sentResponses[1].Resources), 1) + response0 := sentResponses[0].Resources[0] + response1 := sentResponses[1].Resources[0] + + if response0.Name == foo.Name { + testutils.ProtoEquals(t, foo, response0) + testutils.ProtoEquals(t, bar, response1) + } else { + testutils.ProtoEquals(t, bar, response0) + testutils.ProtoEquals(t, foo, response1) + } + + // Delete resources whose names are the same size as the resources to trip the chunker with the same conditions + name1 := strings.Repeat("1", resourceSize) + name2 := strings.Repeat("2", resourceSize) + sentResponses = ds.chunk(map[string]entry{ + name1: {Resource: nil}, + name2: {Resource: nil}, + }) + require.Equal(t, len(sentResponses[0].RemovedResources), 1) + require.Equal(t, len(sentResponses[1].RemovedResources), 1) + require.ElementsMatch(t, + []string{name1, name2}, + []string{sentResponses[0].RemovedResources[0], sentResponses[1].RemovedResources[0]}, + ) + + small1, small2, small3 := "a", "b", "c" + wayTooBig := strings.Repeat("3", 10*resourceSize) + + sentResponses = ds.chunk(map[string]entry{ + small1: {Resource: nil}, + small2: {Resource: nil}, + small3: {Resource: nil}, + wayTooBig: {Resource: nil}, + }) + require.Equal(t, len(sentResponses[0].RemovedResources), 3) + require.ElementsMatch(t, []string{small1, small2, small3}, sentResponses[0].RemovedResources) + require.Equal(t, int64(1), statsHandler.DeltaResourcesOverMaxSize.Load()) +} + +type customStatsHandler struct { + DeltaResourcesOverMaxSize atomic.Int64 `metric:",counter"` +} + +func (h *customStatsHandler) HandleServerEvent(ctx context.Context, event serverstats.Event) { + if _, ok := event.(*serverstats.ResourceOverMaxSize); ok { + h.DeltaResourcesOverMaxSize.Add(1) + } +} diff --git a/internal/server/handlers_test.go 
b/internal/server/handlers_test.go new file mode 100644 index 0000000..76e79a8 --- /dev/null +++ b/internal/server/handlers_test.go @@ -0,0 +1,231 @@ +package internal + +import ( + "maps" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/linkedin/diderot/ads" + "github.com/linkedin/diderot/internal/utils" + "github.com/linkedin/diderot/testutils" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +// TestHandlerDebounce checks the following: +// 1. That the handler does not invoke send as long as the debouncer has not allowed it to. +// 2. That updates that come in while send is being invoked do not get missed. +// 3. That if multiple updates for the same resource come in, only the latest one is respected. +func TestHandlerDebounce(t *testing.T) { + var released atomic.Bool + l := NewTestHandlerLimiter() + + var enterSendWg, continueSendWg sync.WaitGroup + continueSendWg.Add(1) + + actualResources := map[string]entry{} + + h := newHandler( + testutils.Context(t), + NoopLimiter{}, + l, + new(customStatsHandler), + false, + func(resources map[string]entry) error { + require.True(t, released.Swap(false), "send invoked without being released") + require.NotEmpty(t, resources) + enterSendWg.Done() + continueSendWg.Wait() + defer l.Done() + for k, e := range resources { + actualResources[k] = e + } + return nil + }, + ) + + // declare the various times upfront and ensure they are all unique, which will allow validating the interactions + // with the handler + var ( + fooSubscribedAt = time.Now() + fooCreateMetadata = ads.SubscriptionMetadata{ + SubscribedAt: fooSubscribedAt, + ModifiedAt: fooSubscribedAt.Add(2 * time.Hour), + CachedAt: fooSubscribedAt.Add(3 * time.Hour), + } + fooDeleteMetadata = ads.SubscriptionMetadata{ + SubscribedAt: fooSubscribedAt, + ModifiedAt: time.Time{}, + CachedAt: fooSubscribedAt.Add(4 * time.Hour), + } + + barCreateMetadata = ads.SubscriptionMetadata{ + SubscribedAt: fooSubscribedAt.Add(5 * time.Hour), + ModifiedAt: fooSubscribedAt.Add(6 * time.Hour), + CachedAt: fooSubscribedAt.Add(7 * time.Hour), + } + ) + + const foo, bar = "foo", "bar" + barR := new(ads.RawResource) + + h.Notify(foo, new(ads.RawResource), fooCreateMetadata) + h.Notify(foo, nil, fooDeleteMetadata) + + enterSendWg.Add(1) + go func() { + enterSendWg.Wait() + h.Notify(bar, barR, barCreateMetadata) + continueSendWg.Done() + }() + + released.Store(true) + l.Release() + require.Equal(t, + map[string]entry{ + foo: { + Resource: nil, + metadata: fooDeleteMetadata, + }, + }, + actualResources) + delete(actualResources, foo) + + enterSendWg.Add(1) + released.Store(true) + l.Release() + require.Equal( + t, + map[string]entry{ + bar: { + Resource: barR, + metadata: barCreateMetadata, + }, + }, + actualResources, + ) +} + +func TestHandlerBatching(t *testing.T) { + var released atomic.Bool + ch := make(chan map[string]entry) + granular := NewTestHandlerLimiter() + h := newHandler( + testutils.Context(t), + granular, + NoopLimiter{}, + new(customStatsHandler), + false, + func(resources map[string]entry) error { + // Double check that send isn't invoked before it's expected + if !released.Load() { + t.Fatalf("send invoked before release!") + } + ch <- maps.Clone(resources) + return nil + }, + ) + expectedEntries := make(map[string]entry) + notify := func() { + name := strconv.Itoa(len(expectedEntries)) + e := entry{ + Resource: nil, + metadata: ads.SubscriptionMetadata{}, + } + 
h.Notify(name, nil, e.metadata)
+		expectedEntries[name] = e
+	}
+
+	h.StartNotificationBatch()
+	notify()
+
+	for i := 0; i < 100; i++ {
+		notify()
+	}
+	released.Store(true)
+	h.EndNotificationBatch()
+
+	require.Equal(t, expectedEntries, <-ch)
+
+	released.Store(false)
+
+	clear(expectedEntries)
+	notify()
+	granular.WaitForReserve()
+
+	released.Store(true)
+	// Check that EndNotificationBatch skips the granular limiter
+	h.EndNotificationBatch()
+
+	require.Equal(t, expectedEntries, <-ch)
+}
+
+func TestHandlerDoesNothingOnEmptyBatch(t *testing.T) {
+	h := newHandler(
+		testutils.Context(t),
+		// Make both limiters nil, if the handler interacts with them at all the test should fail
+		nil,
+		nil,
+		new(customStatsHandler),
+		false,
+		func(_ map[string]entry) error {
+			require.Fail(t, "notify called")
+			return nil
+		},
+	)
+	h.StartNotificationBatch()
+	h.EndNotificationBatch()
+}
+
+var ignoredMetadata = ads.SubscriptionMetadata{}
+
+func TestPseudoDeltaSotWHandler(t *testing.T) {
+	typeUrl := utils.GetTypeURL[*wrapperspb.BoolValue]()
+	// This test relies on Bool being a pseudo delta resource type, so fail the test early otherwise
+	require.True(t, utils.IsPseudoDeltaSotW(typeUrl))
+
+	l := NewTestHandlerLimiter()
+	var lastRes *ads.SotWDiscoveryResponse
+	h := newSotWHandler(
+		testutils.Context(t),
+		NoopLimiter{},
+		l,
+		new(customStatsHandler),
+		typeUrl,
+		func(res *ads.SotWDiscoveryResponse) error {
+			defer l.Done()
+			lastRes = res
+			return nil
+		},
+	)
+
+	const foo, bar = "foo", "bar"
+	fooR := ads.NewResource(foo, "0", wrapperspb.Bool(true))
+	barR := ads.NewResource(bar, "0", wrapperspb.Bool(false))
+	h.Notify(foo, testutils.MustMarshal(t, fooR), ignoredMetadata)
+
+	l.Release()
+	require.Equal(t, typeUrl, lastRes.TypeUrl)
+	require.ElementsMatch(t, []*anypb.Any{testutils.MustMarshal(t, fooR).Resource}, lastRes.Resources)
+
+	const wait = 500 * time.Millisecond
+	// PseudoDeltaSotW doesn't have a notion of deletions. A deleted resource simply never shows up again unless
+	// it's recreated. The next call to Release should therefore block until the handler invokes l.reserve(), which it
+	// should _not_ do until a resource is created. This test checks that that's the case by deleting foo, waiting
+	// 500ms before creating bar, then checking how long Release blocked, which should be roughly 500ms.
+	h.Notify(foo, nil, ignoredMetadata)
+	go func() {
+		time.Sleep(wait)
+		h.Notify(bar, testutils.MustMarshal(t, barR), ignoredMetadata)
+	}()
+
+	start := time.Now()
+	l.Release()
+	require.WithinDuration(t, time.Now(), start.Add(wait), 10*time.Millisecond)
+	require.ElementsMatch(t, []*anypb.Any{testutils.MustMarshal(t, barR).Resource}, lastRes.Resources)
+}
diff --git a/internal/server/limiter.go b/internal/server/limiter.go
new file mode 100644
index 0000000..f1f2e79
--- /dev/null
+++ b/internal/server/limiter.go
@@ -0,0 +1,33 @@
+package internal
+
+import (
+	"time"
+
+	"golang.org/x/time/rate"
+)
+
+// handlerLimiter is an interface used by the handler implementation. It exists for the sole purpose
+// of testing and is trivially implemented by rate.Limiter using rateLimiterWrapper. It is not
+// exposed in this package's public API.
+type handlerLimiter interface {
+	// reserve returns a channel that will receive the current time (or be closed) once the rate limit
+	// clears. Callers should wait until this occurs before acting. The returned cancel function should
+	// be invoked if the caller did not wait for the rate limit to clear, though it is safe to call even
+	// after the rate limit has cleared. In other words, it is safe to invoke in a deferred expression.
+	reserve() (reservation <-chan time.Time, cancel func())
+}
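The contract is easiest to see from the caller's side. A sketch of typical usage from a handler goroutine, mirroring waitForGlobalLimiter earlier in this patch (limiter and ctx are assumed to be in scope):

```go
reservation, cancel := limiter.reserve()
defer cancel() // safe even after the reservation fires

select {
case <-ctx.Done():
	return ctx.Err() // the stream was closed while waiting
case <-reservation:
	// The rate limit cleared: proceed with the send.
}
```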
The returned cancel function should + // be invoked if the caller did not wait for the rate limit to clear, though it is safe to call even + // after the rate limit has cleared. In other words, it is safe to invoke in a deferred expression. + reserve() (reservation <-chan time.Time, cancel func()) +} + +var _ handlerLimiter = (*rateLimiterWrapper)(nil) + +// rateLimiterWrapper implements handlerLimiter using a rate.Limiter +type rateLimiterWrapper rate.Limiter + +func (w *rateLimiterWrapper) reserve() (reservation <-chan time.Time, cancel func()) { + r := (*rate.Limiter)(w).Reserve() + timer := time.NewTimer(r.Delay()) + return timer.C, func() { + // Stopping the timer cleans up any goroutines or schedules associated with this timer. Invoking this + // after the timer fires is a noop. + timer.Stop() + } +} diff --git a/internal/server/limiter_test.go b/internal/server/limiter_test.go new file mode 100644 index 0000000..b8f5380 --- /dev/null +++ b/internal/server/limiter_test.go @@ -0,0 +1,86 @@ +package internal + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" +) + +var _ handlerLimiter = (*TestRateLimiter)(nil) + +func NewTestHandlerLimiter() *TestRateLimiter { + return &TestRateLimiter{ + cond: sync.NewCond(new(sync.Mutex)), + } +} + +type TestRateLimiter struct { + cond *sync.Cond + ch chan time.Time + wg sync.WaitGroup +} + +func (l *TestRateLimiter) reserve() (<-chan time.Time, func()) { + l.cond.L.Lock() + defer l.cond.L.Unlock() + + if l.ch == nil { + l.ch = make(chan time.Time) + l.cond.Signal() + } + return l.ch, func() {} +} + +func (l *TestRateLimiter) Release() { + l.cond.L.Lock() + if l.ch == nil { + l.cond.Wait() + } + ch := l.ch + l.ch = nil + l.cond.L.Unlock() + + l.wg.Add(1) + ch <- time.Now() + l.wg.Wait() +} + +// WaitForReserve waits for another goroutine to call reserve (if one hasn't already) +func (l *TestRateLimiter) WaitForReserve() { + l.cond.L.Lock() + defer l.cond.L.Unlock() + + if l.ch == nil { + l.cond.Wait() + } +} + +func (l *TestRateLimiter) Done() { + l.wg.Done() +} + +func TestHandlerLimiter(t *testing.T) { + l := (*rateLimiterWrapper)(rate.NewLimiter(10, 1)) + start := time.Now() + ch1, _ := l.reserve() + ch2, _ := l.reserve() + + const delta = float64(5 * time.Millisecond) + require.InDelta(t, 0, (<-ch1).Sub(start), delta) + require.InDelta(t, 100*time.Millisecond, (<-ch2).Sub(start), delta) +} + +var closedReservation = func() chan time.Time { + ch := make(chan time.Time) + close(ch) + return ch +}() + +type NoopLimiter struct{} + +func (n NoopLimiter) reserve() (reservation <-chan time.Time, cancel func()) { + return closedReservation, func() {} +} diff --git a/internal/server/subscription_manager.go b/internal/server/subscription_manager.go new file mode 100644 index 0000000..72b9643 --- /dev/null +++ b/internal/server/subscription_manager.go @@ -0,0 +1,216 @@ +package internal + +import ( + "context" + "sync" + + "github.com/linkedin/diderot/ads" + "github.com/linkedin/diderot/internal/utils" + "google.golang.org/protobuf/proto" +) +
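To make the reserve contract above concrete, here is a minimal usage sketch; waitForLimiter is a hypothetical helper that is not part of this patch, and it relies only on the handlerLimiter interface defined above:

```go
package internal

import "context"

// waitForLimiter blocks until the rate limit clears or ctx ends,
// whichever comes first. Deferring cancel is always safe: it is a
// no-op once the reservation has fired.
func waitForLimiter(ctx context.Context, l handlerLimiter) error {
	reservation, cancel := l.reserve()
	defer cancel()
	select {
	case <-reservation:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
```

+// ResourceLocator is a copy of the interface in the root package, to avoid import cycles.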
+type ResourceLocator interface { + Subscribe( + streamCtx context.Context, + typeURL, resourceName string, + handler ads.RawSubscriptionHandler, + ) (unsubscribe func()) + Resubscribe( + streamCtx context.Context, + typeURL, resourceName string, + handler ads.RawSubscriptionHandler, + ) +} + +type SubscriptionManager[REQ proto.Message] interface { + // ProcessSubscriptions handles subscribing/unsubscribing from the resources provided in the given + // xDS request. This function will always invoke BatchSubscriptionHandler.StartNotificationBatch + // before it starts processing the subscriptions and always complete with + // BatchSubscriptionHandler.EndNotificationBatch. Since the cache implementation always notifies the + // SubscriptionHandler with the current value of the subscribed resource, + // BatchSubscriptionHandler.EndNotificationBatch will be invoked after the handler has been notified + // of all the resources requested. + ProcessSubscriptions(REQ) + // IsSubscribedTo checks whether the client has subscribed to the given resource name. + IsSubscribedTo(name string) bool + // UnsubscribeAll cleans up any active subscriptions and disables the wildcard subscription if enabled. + UnsubscribeAll() +} + +// subscriptionManagerCore keeps track of incoming subscription and unsubscription requests, and +// executes the corresponding actions against the underlying cache. It is meant to be embedded in +// deltaSubscriptionManager and sotWSubscriptionManager to deduplicate the subscription tracking +// logic. +type subscriptionManagerCore struct { + ctx context.Context + locator ResourceLocator + typeURL string + handler BatchSubscriptionHandler + + lock sync.Mutex + subscriptions map[string]func() +} + +func newSubscriptionManagerCore( + ctx context.Context, + locator ResourceLocator, + typeURL string, + handler BatchSubscriptionHandler, +) *subscriptionManagerCore { + c := &subscriptionManagerCore{ + ctx: ctx, + locator: locator, + typeURL: typeURL, + handler: handler, + subscriptions: make(map[string]func()), + } + // Ensure all the subscriptions managed by this subscription manager are cleaned up, otherwise they + // will dangle forever in the cache and prevent the backing SubscriptionHandler from being collected + // as well. + context.AfterFunc(ctx, func() { + c.UnsubscribeAll() + }) + return c +} + +type deltaSubscriptionManager struct { + *subscriptionManagerCore + firstCallReceived bool +} + +// NewDeltaSubscriptionManager creates a new SubscriptionManager specifically designed to handle the +// Delta xDS protocol's subscription semantics. +func NewDeltaSubscriptionManager( + ctx context.Context, + locator ResourceLocator, + typeURL string, + handler BatchSubscriptionHandler, +) SubscriptionManager[*ads.DeltaDiscoveryRequest] { + return &deltaSubscriptionManager{ + subscriptionManagerCore: newSubscriptionManagerCore(ctx, locator, typeURL, handler), + } +} + +type sotWSubscriptionManager struct { + *subscriptionManagerCore + receivedExplicitSubscriptions bool +} + +// NewSotWSubscriptionManager creates a new SubscriptionManager specifically designed to handle the +// State-of-the-World xDS protocol's subscription semantics. 
+func NewSotWSubscriptionManager( + ctx context.Context, + locator ResourceLocator, + typeURL string, + handler BatchSubscriptionHandler, +) SubscriptionManager[*ads.SotWDiscoveryRequest] { + return &sotWSubscriptionManager{ + subscriptionManagerCore: newSubscriptionManagerCore(ctx, locator, typeURL, handler), + } +} + +// ProcessSubscriptions processes the subscriptions for a delta stream. It manages the implicit +// wildcard subscription outlined in [the spec]. The server should default to the wildcard +// subscription if the client's first request does not provide any resource names to explicitly +// subscribe to. The client must then explicitly unsubscribe from the wildcard. Subsequent requests +// that do not provide any explicit resource names will not alter the current subscription state. +// +// [the spec]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol.html#how-the-client-specifies-what-resources-to-return +func (m *deltaSubscriptionManager) ProcessSubscriptions(req *ads.DeltaDiscoveryRequest) { + m.handler.StartNotificationBatch() + defer m.handler.EndNotificationBatch() + + m.lock.Lock() + defer m.lock.Unlock() + + subscribe := req.ResourceNamesSubscribe + if !m.firstCallReceived { + m.firstCallReceived = true + if len(subscribe) == 0 { + subscribe = []string{ads.WildcardSubscription} + } + } + + for _, name := range subscribe { + m.subscribe(name) + } + + for _, name := range req.ResourceNamesUnsubscribe { + m.unsubscribe(name) + } +} + +// ProcessSubscriptions processes the subscriptions for a state of the world stream. It manages the +// implicit wildcard subscription outlined in [the spec]. The server should default to the wildcard +// subscription if the client has not sent any resource names to explicitly subscribe to. After the +// first request that provides explicit resource names, the implicit wildcard subscription should +// disappear.
+// +// [the spec]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol.html#how-the-client-specifies-what-resources-to-return +func (m *sotWSubscriptionManager) ProcessSubscriptions(req *ads.SotWDiscoveryRequest) { + m.handler.StartNotificationBatch() + defer m.handler.EndNotificationBatch() + + m.lock.Lock() + defer m.lock.Unlock() + + subscribe := req.ResourceNames + m.receivedExplicitSubscriptions = m.receivedExplicitSubscriptions || len(subscribe) != 0 + if !m.receivedExplicitSubscriptions { + subscribe = []string{ads.WildcardSubscription} + } + + intersection := utils.Set[string]{} + for _, name := range subscribe { + if _, ok := m.subscriptions[name]; ok { + intersection.Add(name) + } + } + + for name := range m.subscriptions { + if !intersection.Contains(name) { + m.unsubscribe(name) + } + } + + for _, name := range subscribe { + if !intersection.Contains(name) { + m.subscribe(name) + } + } +} + +func (c *subscriptionManagerCore) IsSubscribedTo(name string) bool { + c.lock.Lock() + defer c.lock.Unlock() + + _, nameOk := c.subscriptions[name] + _, wildcardOk := c.subscriptions[ads.WildcardSubscription] + return nameOk || wildcardOk +} + +func (c *subscriptionManagerCore) UnsubscribeAll() { + c.lock.Lock() + defer c.lock.Unlock() + + for name := range c.subscriptions { + c.unsubscribe(name) + } +} + +func (c *subscriptionManagerCore) subscribe(name string) { + _, ok := c.subscriptions[name] + if !ok { + c.subscriptions[name] = c.locator.Subscribe(c.ctx, c.typeURL, name, c.handler) + } else { + c.locator.Resubscribe(c.ctx, c.typeURL, name, c.handler) + } +} + +func (c *subscriptionManagerCore) unsubscribe(name string) { + if unsub, ok := c.subscriptions[name]; ok { + unsub() + delete(c.subscriptions, name) + } +} diff --git a/internal/utils/set.go b/internal/utils/set.go new file mode 100644 index 0000000..ec62521 --- /dev/null +++ b/internal/utils/set.go @@ -0,0 +1,44 @@ +package utils + +import ( + "fmt" + "maps" + "slices" +) + +type Set[T comparable] map[T]struct{} + +func NewSet[T comparable](elements ...T) Set[T] { + s := make(Set[T], len(elements)) + for _, t := range elements { + s.Add(t) + } + return s +} + +func (s Set[T]) Add(t T) bool { + _, ok := s[t] + if ok { + return false + } + s[t] = struct{}{} + return true +} + +func (s Set[T]) Contains(t T) bool { + _, ok := s[t] + return ok +} + +func (s Set[T]) Remove(t T) bool { + _, ok := s[t] + if !ok { + return false + } + delete(s, t) + return true +} + +func (s Set[T]) String() string { + return fmt.Sprint(slices.Collect(maps.Keys(s))) +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go new file mode 100644 index 0000000..42d3029 --- /dev/null +++ b/internal/utils/utils.go @@ -0,0 +1,125 @@ +package utils + +import ( + "encoding/base64" + "slices" + "strconv" + "strings" + "time" + + types "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" +) + +// NonceLength is the length of the string returned by NewNonce. NewNonce encodes the current UNIX +// time in nanos in hex encoding, so the nonce will be 16 characters if the current UNIX nano time is +// greater than 2^60-1. This is because it takes 16 hex characters to encode 64 bits, but only 15 to +// encode 60 bits (the output of strconv.FormatInt is not padded by 0s). 2^60-1 nanos from epoch time +// (January 1st 1970) is 2006-07-14 23:58:24.606, which as of this writing is over 17 years ago. 
This +// is why it's guaranteed that NonceLength will be 16 characters (before that date, encoding the +// nanos only required 15 characters). For the curious, the UNIX nano timestamp will overflow int64 +// some time in 2262, making this constant valid for the next few centuries. +const NonceLength = 16 + +// NewNonce creates a new unique nonce based on the current UNIX time in nanos. It always returns a +// string of length NonceLength. +func NewNonce() string { + // The second parameter to FormatInt is the base, e.g. 2 will return binary, 8 will return octal + // encoding, etc. 16 means FormatInt returns the integer in hex encoding, e.g. 30 => "1e" or + // 1704239351400 => "18ccc94c668". + const hexBase = 16 + return strconv.FormatInt(time.Now().UnixNano(), hexBase) +} + +func GetTypeURL[T proto.Message]() string { + var t T + return getTypeURL(t) +} + +func getTypeURL(t proto.Message) string { + return types.APITypePrefix + string(t.ProtoReflect().Descriptor().FullName()) +} + +// MapToProto serializes the given map using protobuf. It sorts the entries based on the key such that the same map +// always produces the same output. It then encodes the entries by appending the key then value, and b64 encodes the +// entire output. Note that the final b64 encoding step is critical as this function is intended to be used with +// [ads.SotWDiscoveryResponse.Version], which is a string field. In protobuf, string fields must contain valid UTF-8 +// characters, and b64 encoding ensures that. +func MapToProto(m map[string]string) string { + if len(m) == 0 { + return "" + } + + var b []byte + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + slices.Sort(keys) + + for _, k := range keys { + b = protowire.AppendString(b, k) + b = protowire.AppendString(b, m[k]) + } + return base64.StdEncoding.EncodeToString(b) +} + +// ProtoToMap is the inverse of MapToProto. It returns an error on any decoding or deserialization issues. +func ProtoToMap(s string) (map[string]string, error) { + if s == "" { + return nil, nil + } + + b, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return nil, err + } + m := make(map[string]string) + + parse := func() (string, error) { + s, n := protowire.ConsumeString(b) + if n < 0 { + return "", protowire.ParseError(n) + } + b = b[n:] + return s, nil + } + + for len(b) > 0 { + k, err := parse() + if err != nil { + return nil, err + } + v, err := parse() + if err != nil { + return nil, err + } + m[k] = v + } + + return m, nil +} + +// IsPseudoDeltaSotW checks whether the given resource type url is intended to behave as a "pseudo +// delta" resource. Instead of sending the entire state of the world for every resource change, the +// server is expected to only send the changed resource. From [the spec]: +// +// In the SotW protocol variants, all resource types except for Listener and Cluster are grouped into +// responses in the same way as in the incremental protocol variants. However, Listener and Cluster +// resource types are handled differently: the server must include the complete state of the world, +// meaning that all resources of the relevant type that are needed by the client must be included, +// even if they did not change since the last response. +// +// In other words, for everything except Listener and Cluster, the server should only send the +// changed resources, rather than every resource every time. 
+// +// [the spec]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol#grouping-resources-into-responses +func IsPseudoDeltaSotW(typeURL string) bool { + return !(typeURL == types.ListenerType || typeURL == types.ClusterType) +} + +// TrimTypeURL removes the leading "type.googleapis.com/" prefix from the given string. +func TrimTypeURL(typeURL string) string { + return strings.TrimPrefix(typeURL, types.APITypePrefix) +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go new file mode 100644 index 0000000..d959668 --- /dev/null +++ b/internal/utils/utils_test.go @@ -0,0 +1,51 @@ +package utils + +import ( + "testing", + + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "github.com/linkedin/diderot/ads" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" +) + +func TestGetTypeURL(t *testing.T) { + require.Equal(t, resource.ListenerType, GetTypeURL[*ads.Listener]()) + require.Equal(t, resource.EndpointType, GetTypeURL[*ads.Endpoint]()) + require.Equal(t, resource.ClusterType, GetTypeURL[*ads.Cluster]()) + require.Equal(t, resource.RouteType, GetTypeURL[*ads.Route]()) +} + +func TestProtoMap(t *testing.T) { + t.Run("good", func(t *testing.T) { + m := map[string]string{ + "foo": "bar", + "baz": "qux", + "empty": "", + "": "empty", + } + s := MapToProto(m) + m2, err := ProtoToMap(s) + require.NoError(t, err) + require.Equal(t, m, m2) + + // Check that on a different invocation, the output remains the same + require.Equal(t, s, MapToProto(m)) + + m2, err = ProtoToMap("") + require.NoError(t, err) + require.Empty(t, m2) + }) + t.Run("bad", func(t *testing.T) { + _, err := ProtoToMap("1") + require.Error(t, err) + + b := protowire.AppendString(nil, "foo") + _, err = ProtoToMap(string(b)) + require.Error(t, err) + }) +} + +func TestNonceLength(t *testing.T) { + require.Len(t, NewNonce(), NonceLength) +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..b8980a7 --- /dev/null +++ b/server.go @@ -0,0 +1,479 @@ +package diderot + +import ( + "context" + "log/slog" + "sync" + "time" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/linkedin/diderot/ads" + internal "github.com/linkedin/diderot/internal/server" + "github.com/linkedin/diderot/internal/utils" + serverstats "github.com/linkedin/diderot/stats/server" + "golang.org/x/time/rate" + grpcStatus "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/protobuf/proto" +) + +var _ ads.Server = (*ADSServer)(nil) + +// An ADSServer is an implementation of the xDS protocol. It implements the tricky parts of an xDS +// control plane such as managing subscriptions, parsing the incoming [ads.SotWDiscoveryRequest] and +// [ads.DeltaDiscoveryRequest], etc. The actual business logic of locating the resources is injected +// via the given ResourceLocator. +type ADSServer struct { + discovery.UnimplementedAggregatedDiscoveryServiceServer + + locator ResourceLocator + + requestLimiter *rate.Limiter + globalLimiter *rate.Limiter + statsHandler serverstats.Handler + maxDeltaResponseSize int + controlPlane *corev3.ControlPlane + + granularLimitLock sync.Mutex + granularLimit rate.Limit + granularLimiters utils.Set[*rate.Limiter] +} +
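Before the option definitions that follow, a hedged construction sketch may help. NewADSServer, the With* options, and the generated RegisterAggregatedDiscoveryServiceServer all appear in this patch; the locator parameter and the specific limit values are assumptions for illustration only:

```go
package main

import (
	discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
	"github.com/linkedin/diderot"
	"google.golang.org/grpc"
)

// newServer wires an ADSServer into a gRPC server. The limits are
// illustrative, not recommendations.
func newServer(locator diderot.ResourceLocator) *grpc.Server {
	srv := diderot.NewADSServer(locator,
		diderot.WithRequestRateLimit(100),         // incoming requests, per server
		diderot.WithGlobalResponseRateLimit(1000), // outgoing responses, per server
		diderot.WithGranularResponseRateLimit(10), // outgoing responses, per client and type
	)
	s := grpc.NewServer()
	discovery.RegisterAggregatedDiscoveryServiceServer(s, srv)
	return s
}
```

+// NewADSServer creates a new [*ADSServer] with the given options.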
+func NewADSServer(locator ResourceLocator, options ...ADSServerOption) *ADSServer { + s := &ADSServer{ + locator: locator, + + requestLimiter: rate.NewLimiter(rate.Inf, 1), + globalLimiter: rate.NewLimiter(rate.Inf, 1), + granularLimit: rate.Inf, + granularLimiters: utils.NewSet[*rate.Limiter](), + } + for _, opt := range options { + opt.apply(s) + } + + return s +} + +// ADSServerOption configures how the ADS Server is initialized. +type ADSServerOption interface { + apply(s *ADSServer) +} + +type serverOption func(s *ADSServer) + +func (f serverOption) apply(s *ADSServer) { + f(s) +} + +// defaultLimit interprets the given limit according to the documentation outlined in the various +// WithLimit options, i.e. if the given limit is negative, 0 or rate.Inf, it returns rate.Inf which, +// if given to a rate.Limiter, disables the rate limiting. +func defaultLimit(limit rate.Limit) rate.Limit { + if limit <= 0 { + return rate.Inf + } + return limit +} + +// WithRequestRateLimit sets the rate limiting parameters for client requests. When a client's +// request is being limited, it will block all other requests for that client until the rate limiting +// expires. If this option is not specified, or if 0, a negative value or rate.Inf is provided, this +// feature is disabled. +func WithRequestRateLimit(limit rate.Limit) ADSServerOption { + return serverOption(func(s *ADSServer) { + s.SetRequestRateLimit(limit) + }) +} + +// WithGlobalResponseRateLimit enforces a maximum rate at which the server will respond to clients. +// This prevents clients from being overloaded with responses and throttles the resource consumption +// on the server. If this option is not specified, or if 0, a negative value or rate.Inf is provided, +// this feature is disabled. +func WithGlobalResponseRateLimit(globalLimit rate.Limit) ADSServerOption { + return serverOption(func(s *ADSServer) { + s.SetGlobalResponseRateLimit(globalLimit) + }) +} + +// WithGranularResponseRateLimit is an additional layer of rate limiting to the one provided by +// WithGlobalResponseRateLimit. If specified, it will be applied to each resource type requested by +// each client. For example, a client can receive updates to its LDS, RDS, CDS and EDS subscriptions +// at a rate of 10 responses per second per type, for a potential maximum rate of 40 responses per +// second since it is subscribed to 4 individual types. When determining how long a response should +// be stalled, however, the server computes the wait time required to satisfy both limits and picks +// the larger of the two. This means this granular limit cannot override the global limit. If not +// specified, this feature is disabled. +func WithGranularResponseRateLimit(granularLimit rate.Limit) ADSServerOption { + return serverOption(func(s *ADSServer) { + s.SetGranularResponseRateLimit(granularLimit) + }) +} + +// WithServerStatsHandler registers a stats handler for the server. The given handler will be invoked +// whenever a corresponding event happens. See the [stats] package for more details. +func WithServerStatsHandler(statsHandler serverstats.Handler) ADSServerOption { + return serverOption(func(s *ADSServer) { + s.statsHandler = statsHandler + }) +} + +// WithMaxDeltaResponseSize limits the size of responses sent by the server when the Delta variant of the xDS protocol +// is being used. As it builds the response from the set of resource updates it wants to send, the server will check +// how large the serialized message will be, stopping before it reaches the threshold.
It then sends the chunk it +// has built up to that point, and repeats the process until the desired set of updates has been sent. Note that +// this cannot be implemented for SotW protocols due to the nature of the protocol itself. +// This setting is ignored if 0, and the feature is disabled by default. +func WithMaxDeltaResponseSize(maxResponseSize int) ADSServerOption { + return serverOption(func(s *ADSServer) { + s.maxDeltaResponseSize = maxResponseSize + }) +} + +// WithControlPlane causes the server to include the given corev3.ControlPlane instance in each response. +func WithControlPlane(controlPlane *corev3.ControlPlane) ADSServerOption { + return serverOption(func(s *ADSServer) { + s.controlPlane = controlPlane + }) +} + +// SetRequestRateLimit updates the incoming request rate limit. If the given limit is 0, negative or +// [rate.Inf], it disables the rate limiting. +func (s *ADSServer) SetRequestRateLimit(newLimit rate.Limit) { + s.requestLimiter.SetLimit(defaultLimit(newLimit)) +} + +// GetRequestRateLimit returns the current incoming request rate limit. +func (s *ADSServer) GetRequestRateLimit() rate.Limit { + return s.requestLimiter.Limit() +} + +// SetGlobalResponseRateLimit updates the global response rate limit. If the given limit is 0, +// negative or [rate.Inf], it disables the rate limiting. +func (s *ADSServer) SetGlobalResponseRateLimit(newLimit rate.Limit) { + s.globalLimiter.SetLimit(defaultLimit(newLimit)) +} + +// GetGlobalResponseRateLimit returns the current global response rate limit. +func (s *ADSServer) GetGlobalResponseRateLimit() rate.Limit { + return s.globalLimiter.Limit() +} + +// SetGranularResponseRateLimit updates the granular response rate limit. If the given limit is 0, +// negative or [rate.Inf], it disables the rate limiting. +func (s *ADSServer) SetGranularResponseRateLimit(newLimit rate.Limit) { + s.granularLimitLock.Lock() + defer s.granularLimitLock.Unlock() + s.granularLimit = defaultLimit(newLimit) + for l := range s.granularLimiters { + l.SetLimit(s.granularLimit) + } +} + +func (s *ADSServer) newGranularRateLimiter() *rate.Limiter { + s.granularLimitLock.Lock() + defer s.granularLimitLock.Unlock() + + l := rate.NewLimiter(s.granularLimit, 1) + s.granularLimiters.Add(l) + return l +} + +// GetGranularResponseRateLimit returns the current granular response rate limit. +func (s *ADSServer) GetGranularResponseRateLimit() rate.Limit { + return s.granularLimit +} +
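The chunking behavior described in WithMaxDeltaResponseSize can be illustrated with a simplified sketch. This is not the patch's implementation (which sizes the actual response message, including per-field overhead); chunkBySize is a hypothetical name, and summing proto.Size over the elements is a deliberate simplification:

```go
package main

import (
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/anypb"
)

// chunkBySize greedily packs resources into chunks, flushing the current
// chunk whenever adding the next resource would push the running size
// past maxSize. A maxSize of 0 disables chunking, mirroring the option's
// documented behavior.
func chunkBySize(resources []*anypb.Any, maxSize int) [][]*anypb.Any {
	var chunks [][]*anypb.Any
	var current []*anypb.Any
	size := 0
	for _, r := range resources {
		s := proto.Size(r)
		if maxSize > 0 && len(current) > 0 && size+s > maxSize {
			chunks = append(chunks, current)
			current, size = nil, 0
		}
		current = append(current, r)
		size += s
	}
	if len(current) > 0 {
		chunks = append(chunks, current)
	}
	return chunks
}
```

+// StreamAggregatedResources is the implementation of the state-of-the-world variant of the ADS protocol.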
+func (s *ADSServer) StreamAggregatedResources(stream ads.SotWStream) (err error) { + h := &streamHandler[*ads.SotWDiscoveryRequest, *ads.SotWDiscoveryResponse]{ + server: s, + stream: stream, + streamType: ads.SotWStreamType, + newHandler: func( + ctx context.Context, + granularLimiter *rate.Limiter, + statsHandler serverstats.Handler, + typeUrl string, + send func(*ads.SotWDiscoveryResponse) error, + ) internal.BatchSubscriptionHandler { + return internal.NewSotWHandler( + ctx, + granularLimiter, + s.globalLimiter, + statsHandler, + typeUrl, + send, + ) + }, + newManager: internal.NewSotWSubscriptionManager, + noSuchTypeResponse: func(req *ads.SotWDiscoveryRequest) *ads.SotWDiscoveryResponse { + return &ads.SotWDiscoveryResponse{ + Resources: nil, + TypeUrl: req.TypeUrl, + Nonce: utils.NewNonce(), + } + }, + setControlPlane: func(res *ads.SotWDiscoveryResponse, controlPlane *corev3.ControlPlane) { + res.ControlPlane = controlPlane + }, + } + + return h.loop() +} + +// DeltaAggregatedResources is the implementation of the delta/incremental variant of the ADS +// protocol. +func (s *ADSServer) DeltaAggregatedResources(stream ads.DeltaStream) (err error) { + h := &streamHandler[*ads.DeltaDiscoveryRequest, *ads.DeltaDiscoveryResponse]{ + server: s, + stream: stream, + streamType: ads.DeltaStreamType, + // TODO: respect the initial_resource_versions map instead of sending everything every time + newHandler: func( + ctx context.Context, + responseLimiter *rate.Limiter, + statsHandler serverstats.Handler, + typeUrl string, + send func(*ads.DeltaDiscoveryResponse) error, + ) internal.BatchSubscriptionHandler { + return internal.NewDeltaHandler( + ctx, + responseLimiter, + s.globalLimiter, + statsHandler, + s.maxDeltaResponseSize, + typeUrl, + send, + ) + }, + newManager: internal.NewDeltaSubscriptionManager, + noSuchTypeResponse: func(req *ads.DeltaDiscoveryRequest) *ads.DeltaDiscoveryResponse { + return &ads.DeltaDiscoveryResponse{ + TypeUrl: req.GetTypeUrl(), + RemovedResources: req.GetResourceNamesSubscribe(), + Nonce: utils.NewNonce(), + ControlPlane: s.controlPlane, + } + }, + setControlPlane: func(res *ads.DeltaDiscoveryResponse, controlPlane *corev3.ControlPlane) { + res.ControlPlane = controlPlane + }, + } + + return h.loop() +} + +type adsDiscoveryRequest interface { + proto.Message + GetTypeUrl() string + GetResponseNonce() string + GetErrorDetail() *grpcStatus.Status + GetNode() *ads.Node +} + +type adsStream[REQ adsDiscoveryRequest, RES proto.Message] interface { + Context() context.Context + Recv() (REQ, error) + Send(RES) error +} + +// streamHandler captures the various elements required to handle an ADS stream. +type streamHandler[REQ adsDiscoveryRequest, RES proto.Message] struct { + sendLock sync.Mutex + + server *ADSServer + stream adsStream[REQ, RES] + streamCtx context.Context + streamType ads.StreamType + newHandler func( + ctx context.Context, + granularLimiter *rate.Limiter, + statsHandler serverstats.Handler, + typeUrl string, + send func(RES) error, + ) internal.BatchSubscriptionHandler + newManager func( + ctx context.Context, + locator internal.ResourceLocator, + typeURL string, + handler internal.BatchSubscriptionHandler, + ) internal.SubscriptionManager[REQ] + noSuchTypeResponse func(req REQ) RES + setControlPlane func(res RES, controlPlane *corev3.ControlPlane) + aggregateSubscriptions map[string]internal.SubscriptionManager[REQ] +} + +// send invokes Send on the stream with the given response, returning an error if Send returns an error. 
Crucially, +// Send can only be invoked by one goroutine at a time, so this function protects the invocation of Send with sendLock. +func (h *streamHandler[REQ, RES]) send(res RES) (err error) { + if h.server.statsHandler != nil { + start := time.Now() + defer func() { + h.server.statsHandler.HandleServerEvent(h.streamCtx, &serverstats.ResponseSent{ + Res: res, + Duration: time.Since(start), + }) + }() + } + + h.sendLock.Lock() + defer h.sendLock.Unlock() + h.setControlPlane(res, h.server.controlPlane) + slog.DebugContext(h.streamCtx, "Sending", "msg", res) + return h.stream.Send(res) +} + +func (h *streamHandler[REQ, RES]) recv() (REQ, error) { + // TODO: Introduce a timeout on receiving the first request. In order to keep a stream alive, gRPC needs to send + // keepalives etc. If a client never sends the first request to identify itself etc it should eventually be kicked + // since it is wasting resources. + return h.stream.Recv() +} + +// getSubscriptionManager returns a [internal.SubscriptionManager] for the given type url. If the +// type is not supported (checked via the ResourceLocator), this function returns nil, false. This +// indicates that the given type is unknown by the system and the request should be ignored. +// Subsequent calls to this function with the same type url always return the same subscription +// manager. +func (h *streamHandler[REQ, RES]) getSubscriptionManager( + typeURL string, +) (internal.SubscriptionManager[REQ], bool) { + // Manager was already created, return immediately. + if manager, ok := h.aggregateSubscriptions[typeURL]; ok { + return manager, true + } + + if !h.server.locator.IsTypeSupported(h.streamCtx, typeURL) { + return nil, false + } + + manager := h.newManager( + h.streamCtx, + h.server.locator, + typeURL, + h.newHandler( + h.streamCtx, + h.server.newGranularRateLimiter(), + h.server.statsHandler, + typeURL, + h.send, + ), + ) + + h.aggregateSubscriptions[typeURL] = manager + return manager, true +} + +func (h *streamHandler[REQ, RES]) loop() error { + for { + req, err := h.recv() + if err != nil { + return err + } + + // initialize the stream context with the node on the first request + if h.streamCtx == nil { + h.streamCtx = context.WithValue(h.stream.Context(), nodeContextKey{}, req.GetNode()) + } + + err = h.handleRequest(req) + if err != nil { + return err + } + } +} + +func (h *streamHandler[REQ, RES]) handleRequest(req REQ) (err error) { + slog.DebugContext(h.streamCtx, "Received request", "req", req) + + var stat *serverstats.RequestReceived + if h.server.statsHandler != nil { + start := time.Now() + stat = &serverstats.RequestReceived{Req: req} + defer func() { + stat.Duration = time.Since(start) + h.server.statsHandler.HandleServerEvent(h.streamCtx, stat) + }() + } + + err = h.server.requestLimiter.Wait(h.streamCtx) + if err != nil { + return err + } + + if h.aggregateSubscriptions == nil { + h.aggregateSubscriptions = make(map[string]internal.SubscriptionManager[REQ]) + } + + typeURL := req.GetTypeUrl() + manager, ok := h.getSubscriptionManager(typeURL) + if !ok { + slog.WarnContext(h.streamCtx, "Ignoring unknown requested type", "typeURL", typeURL, "req", req) + if stat != nil { + stat.IsRequestedTypeUnknown = true + } + return h.send(h.noSuchTypeResponse(req)) + } + + switch { + case req.GetErrorDetail() != nil: + slog.WarnContext(h.streamCtx, "Got client NACK", "req", req) + if stat != nil { + stat.IsNACK = true + } + case req.GetResponseNonce() != "": + slog.DebugContext(h.streamCtx, "ACKED", "req", req) + if stat != nil { + 
stat.IsACK = true + } + } + + manager.ProcessSubscriptions(req) + + return nil +} + +// The ResourceLocator abstracts away the business logic used to locate resources and subscribe to +// them. For example, while Subscribe is trivially implemented with a [Cache] which only serves +// static predetermined resources, it could be implemented to instead generate a resource definition +// on the fly, based on the client's attributes. Alternatively, some attribute in the client's +// [ads.Node] may show that the client does not support IPv6 and should instead be shown IPv4 +// addresses in the [ads.Endpoint] response. +// +// Many users of this library may also choose to implement a +// [google.golang.org/grpc.StreamServerInterceptor] to populate additional values in the stream's +// context, which can be used to better identify the client. However, for convenience, the [ads.Node] +// provided in the request will always be provided in the stream context, and can be accessed with +// [NodeFromContext]. +type ResourceLocator interface { + // IsTypeSupported is used to check whether the given client supports the requested type. + IsTypeSupported(streamCtx context.Context, typeURL string) bool + // Subscribe subscribes the given handler to the desired resource. The returned function should + // unsubscribe the handler from the resource. It is guaranteed that the desired type has already been + // checked via IsTypeSupported, and is therefore supported. + Subscribe( + streamCtx context.Context, + typeURL, resourceName string, + handler ads.RawSubscriptionHandler, + ) (unsubscribe func()) + // Resubscribe will be called whenever a client resubscribes to a given resource. The xDS protocol + // dictates that re-subscribing to a resource should cause the server to re-send the resource. Note + // that implementations of this interface that leverage a [Cache] already support this behavior + // out-of-the-box. + Resubscribe( + streamCtx context.Context, + typeURL, resourceName string, + handler ads.RawSubscriptionHandler, + ) +} + +type nodeContextKey struct{} +
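As a hedged illustration of the pattern described above, a ResourceLocator implementation can consult the node to make per-client decisions. clientSupportsIPv6 and the "example.ipv6" feature name are hypothetical; NodeFromContext is declared just below:

```go
package diderot

import "context"

// clientSupportsIPv6 shows how locator methods, which all receive the
// stream context, can inspect the client's node. The feature name is
// made up for this sketch.
func clientSupportsIPv6(streamCtx context.Context) bool {
	node, ok := NodeFromContext(streamCtx)
	if !ok {
		return false
	}
	for _, feature := range node.GetClientFeatures() {
		if feature == "example.ipv6" {
			return true
		}
	}
	return false
}
```

+// NodeFromContext returns the [ads.Node] in the given context, if it exists. Note that the +// [ADSServer] will always provide the Node in the context when invoking methods on the +// [ResourceLocator].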
+func NodeFromContext(streamCtx context.Context) (*ads.Node, bool) { + node, ok := streamCtx.Value(nodeContextKey{}).(*ads.Node) + return node, ok +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..6d52225 --- /dev/null +++ b/server_test.go @@ -0,0 +1,984 @@ +package diderot + +import ( + "context" + "crypto/rand" + "encoding/json" + "net" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/linkedin/diderot/ads" + internal "github.com/linkedin/diderot/internal/server" + "github.com/linkedin/diderot/internal/utils" + serverstats "github.com/linkedin/diderot/stats/server" + "github.com/linkedin/diderot/testutils" + "github.com/stretchr/testify/require" + "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/xds" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +var ( + badTypeURL = "foobar" + badResources = []string{"badResource1", "badResource2"} + controlPlane = &core.ControlPlane{Identifier: "fooBar"} +) + +type serverStatsHandler struct { + UnknownTypes atomic.Int64 + UnknownResources atomic.Int64 + NACKsReceived atomic.Int64 + ACKsReceived atomic.Int64 +} + +func (m *serverStatsHandler) HandleServerEvent(ctx context.Context, event serverstats.Event) { + switch e := event.(type) { + case *serverstats.RequestReceived: + if e.IsNACK { + m.NACKsReceived.Add(1) + } + if e.IsACK { + m.ACKsReceived.Add(1) + } + if e.IsRequestedTypeUnknown { + m.UnknownTypes.Add(1) + } + case *serverstats.ResourceQueued: + if !e.ResourceExists { + m.UnknownResources.Add(1) + } + } +} + +func (m *serverStatsHandler) reset() { + m.UnknownTypes.Store(0) + m.UnknownResources.Store(0) + m.NACKsReceived.Store(0) + m.ACKsReceived.Store(0) +} + +type testLocator struct { + t *testing.T + node *ads.Node + caches map[string]RawCache +} + +func (tl *testLocator) checkContextNode(streamCtx context.Context) { + if tl.node == nil { + // skip checking node if nil + return + } + node, ok := NodeFromContext(streamCtx) + require.True(tl.t, ok) + testutils.ProtoEquals(tl.t, tl.node, node) +} + +func (tl *testLocator) IsTypeSupported(streamCtx context.Context, typeURL string) bool { + tl.checkContextNode(streamCtx) + _, ok := tl.caches[typeURL] + return ok +} + +func (tl *testLocator) Subscribe( + streamCtx context.Context, + typeURL, resourceName string, + handler ads.RawSubscriptionHandler, +) (unsubscribe func()) { + tl.checkContextNode(streamCtx) + + c := tl.caches[typeURL] + Subscribe(c, resourceName, handler) + return func() { + Unsubscribe(c, resourceName, handler) + } +} + +func (tl *testLocator) Resubscribe( + streamCtx context.Context, + typeURL, resourceName string, + handler ads.RawSubscriptionHandler, +) { + tl.checkContextNode(streamCtx) + Subscribe(tl.caches[typeURL], resourceName, handler) +} + +func newTestLocator(t *testing.T, node *ads.Node, types ...Type) *testLocator { + tl := &testLocator{ + t: t, + node: node, + caches: make(map[string]RawCache), + } + for _, tpe := range types { + tl.caches[tpe.URL()] = tpe.NewCache() + } + return tl +} + +func getCache[T proto.Message](tl *testLocator) Cache[T] 
{ + return tl.caches[TypeOf[T]().URL()].(Cache[T]) +} + +func TestEndToEnd(t *testing.T) { + locator := newTestLocator( + t, + &ads.Node{ + Id: "diderot-test", + UserAgentName: "gRPC Go", + UserAgentVersionType: &core.Node_UserAgentVersion{UserAgentVersion: grpc.Version}, + ClientFeatures: []string{ + "envoy.lb.does_not_support_overprovisioning", + "xds.config.resource-in-sotw", + }, + }, + TypeOf[*ads.Endpoint](), + TypeOf[*ads.Cluster](), + TypeOf[*ads.Route](), + TypeOf[*ads.Listener](), + TypeOf[*wrapperspb.BytesValue](), + ) + + endpointCache := getCache[*ads.Endpoint](locator) + listenerCache := getCache[*ads.Listener](locator) + bytesCache := getCache[*wrapperspb.BytesValue](locator) + + ts := testutils.NewTestGRPCServer(t) + + resources := readResourcesFromJSONFile(t, "test_xds_config.json") + require.Len(t, resources, 3) + + for _, r := range resources { + c, ok := locator.caches[r.Resource.TypeUrl] + require.Truef(t, ok, "Unknown type loaded from test config %q: %+v", r.Resource.TypeUrl, r) + require.NoError(t, c.SetRaw(r, time.Now())) + } + + addr := ts.Addr().(*net.TCPAddr) + endpointCache.Set( + "testADSServer", + "0", + &ads.Endpoint{ + ClusterName: "testADSServer", + Endpoints: []*endpoint.LocalityLbEndpoints{{ + Locality: new(core.Locality), + LoadBalancingWeight: wrapperspb.UInt32(1), + LbEndpoints: []*endpoint.LbEndpoint{{ + HostIdentifier: &endpoint.LbEndpoint_Endpoint{ + Endpoint: &endpoint.Endpoint{ + Address: &core.Address{ + Address: &core.Address_SocketAddress{ + SocketAddress: &core.SocketAddress{ + Protocol: core.SocketAddress_TCP, + Address: addr.IP.String(), + PortSpecifier: &core.SocketAddress_PortValue{PortValue: uint32(addr.Port)}, + }, + }, + }, + Hostname: "localhost", + }, + }, + }}, + }}, + }, + time.Now(), + ) + + statsHandler := new(serverStatsHandler) + + s := NewADSServer( + locator, + WithGranularResponseRateLimit(0), + WithGlobalResponseRateLimit(0), + WithServerStatsHandler(statsHandler), + WithControlPlane(controlPlane), + ) + discovery.RegisterAggregatedDiscoveryServiceServer(ts.Server, s) + ts.Start() + + xdsResolverBuilder, err := xds.NewXDSResolverWithConfigForTesting([]byte(`{ + "xds_servers": [ + { + "server_uri": "` + ts.AddrString() + `", + "channel_creds": [{"type": "insecure"}], + "server_features": ["xds_v3"] + } + ], + "node": { "id": "diderot-test" } + }`)) + require.NoError(t, err) + + conn, err := grpc.NewClient( + "xds:///testADSServer", + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithResolvers(xdsResolverBuilder), + ) + require.NoError(t, err) + + client := discovery.NewAggregatedDiscoveryServiceClient(conn) + t.Run("xDS sanity check", func(t *testing.T) { + // It's expected that the xDS client will have ACKed the responses it received from the server during + // the LDS -> RDS -> CDS -> EDS flow. However, since grpc.NewClient does not actually attempt to + // establish a connection until the very last moment, testing this actually requires opening a + // stream. This test opens a stream, checks the ACK counter then closes it. 
+ + stream, err := client.DeltaAggregatedResources(testutils.ContextWithTimeout(t, 5*time.Second)) + require.NoError(t, err) + + require.Equal(t, int64(4), statsHandler.ACKsReceived.Load()) + + require.NoError(t, stream.CloseSend()) + }) + + testData := &wrapperspb.BytesValue{Value: make([]byte, 20)} + _, err = rand.Read(testData.Value) + require.NoError(t, err) + testResource := ads.NewResource("testData", "0", testData) + clearEntry := func() { + bytesCache.Clear(testResource.Name, time.Now()) + } + setEntry := func(t *testing.T) { + bytesCache.SetResource(testResource, time.Now()) + t.Cleanup(clearEntry) + } + + t.Run("delta", func(t *testing.T) { + statsHandler.reset() + + stream, err := client.DeltaAggregatedResources(testutils.ContextWithTimeout(t, 5*time.Second)) + require.NoError(t, err) + + req := &ads.DeltaDiscoveryRequest{ + Node: locator.node, + TypeUrl: testResource.TypeURL(), + ResourceNamesSubscribe: []string{testResource.Name}, + } + + require.NoError(t, stream.Send(req)) + + res := new(ads.DeltaDiscoveryResponse) + waitForResponse(t, res, stream, 10*time.Millisecond) + + require.Equal(t, res.RemovedResources, []string{testResource.Name}) + require.Equal(t, int64(1), statsHandler.UnknownResources.Load()) + + setEntry(t) + + waitForResponse(t, res, stream, 10*time.Millisecond) + + require.Len(t, res.Resources, 1) + testutils.ProtoEquals(t, testutils.MustMarshal(t, testResource), res.Resources[0]) + + // check that re-subscribing to a resource causes the server to resend it. + require.NoError(t, stream.Send(req)) + + waitForResponse(t, res, stream, 10*time.Millisecond) + + require.Len(t, res.Resources, 1) + testutils.ProtoEquals(t, testutils.MustMarshal(t, testResource), res.Resources[0]) + + req.ResourceNamesSubscribe = nil + req.ResponseNonce = res.Nonce + require.NoError(t, stream.Send(req)) + + // It's hard to test for the _absence_ of a response to the ACK, however the followup check for the + // removed resource will fail if the server responds to the ACK with anything unexpected. The test + // can still be forced to fail early by checking the value of the ACK metric. + require.Eventually(t, func() bool { + return statsHandler.ACKsReceived.Load() == 1 + }, 2*time.Second, 100*time.Millisecond) + + clearEntry() + + waitForResponse(t, res, stream, 10*time.Millisecond) + + require.Len(t, res.Resources, 0) + require.Equal(t, res.RemovedResources, []string{testResource.Name}) + + // However, the server should respect subscription changes in an ACK. By subscribing to a resource + // that does not exist, it can be forced to respond with a deletion. + req.ResourceNamesSubscribe = []string{"noSuchResource"} + req.ResponseNonce = res.Nonce + require.NoError(t, stream.Send(req)) + + waitForResponse(t, res, stream, 10*time.Millisecond) + + require.Len(t, res.Resources, 0) + require.Equal(t, res.RemovedResources, req.ResourceNamesSubscribe) + require.Equal(t, int64(2), statsHandler.ACKsReceived.Load()) + + // Finally, the NACK metric can be tested by NACKing the previous response. No response is expected + // from the server for this NACK, so the NACK metric needs to be checked. 
+ req.ResourceNamesSubscribe = nil + req.ResponseNonce = res.Nonce + req.ErrorDetail = &status.Status{ + Code: 420, + Message: "Testing NACK", + } + require.NoError(t, stream.Send(req)) + + require.Eventually(t, func() bool { + return statsHandler.NACKsReceived.Load() == 1 + }, 2*time.Second, 100*time.Millisecond) + + require.NoError(t, stream.Send(&ads.DeltaDiscoveryRequest{ + Node: new(core.Node), + TypeUrl: badTypeURL, + ResourceNamesSubscribe: badResources, + })) + + waitForResponse(t, res, stream, 10*time.Millisecond) + + require.Equal(t, badTypeURL, res.GetTypeUrl()) + require.Equal(t, badResources, res.GetRemovedResources()) + require.Equal(t, int64(1), statsHandler.UnknownTypes.Load()) + }) + + t.Run("SotW", func(t *testing.T) { + statsHandler.reset() + + testListener1 := ads.NewResource("testListener1", "0", &ads.Listener{Name: "testListener1"}) + testListener2 := ads.NewResource("testListener2", "1", &ads.Listener{Name: "testListener2"}) + // This test relies on Listener not being a pseudo delta resource type, so fail the test early otherwise + require.False(t, utils.IsPseudoDeltaSotW(testListener1.TypeURL())) + + stream, err := client.StreamAggregatedResources(testutils.Context(t)) + require.NoError(t, err) + + req := &ads.SotWDiscoveryRequest{ + Node: locator.node, + TypeUrl: testListener1.TypeURL(), + ResourceNames: []string{testListener1.Name, testListener2.Name}, + } + require.NoError(t, stream.Send(req)) + + res := new(ads.SotWDiscoveryResponse) + waitForResponse(t, res, stream, 10*time.Millisecond) + require.Empty(t, res.Resources) + require.Equal(t, int64(2), statsHandler.UnknownResources.Load()) + + listenerCache.SetResource(testListener1, time.Now()) + t.Cleanup(func() { + listenerCache.Clear(testListener1.Name, time.Now()) + }) + waitForResponse(t, res, stream, 10*time.Millisecond) + if len(res.Resources) == 0 { + // This can happen because the responses from the cache notifying the client that testListener1 and + // testListener2 arrive asynchronously, so the server may have sent the response for testListener1 + // not being present before receiving the notification for testListener2. This is simply a property + // of SotW, and it's hard to work around. + waitForResponse(t, res, stream, 10*time.Millisecond) + } + require.Len(t, res.Resources, 1) + testutils.ProtoEquals(t, testutils.MustMarshal(t, testListener1).Resource, res.Resources[0]) + + listenerCache.SetResource(testListener2, time.Now()) + waitForResponse(t, res, stream, 10*time.Millisecond) + require.Len(t, res.Resources, 2) + // Order is not guaranteed, so it must be checked explicitly + if proto.Equal(testutils.MustMarshal(t, testListener1).Resource, res.Resources[0]) { + testutils.ProtoEquals(t, testutils.MustMarshal(t, testListener2).Resource, res.Resources[1]) + } else { + testutils.ProtoEquals(t, testutils.MustMarshal(t, testListener1).Resource, res.Resources[1]) + testutils.ProtoEquals(t, testutils.MustMarshal(t, testListener2).Resource, res.Resources[0]) + } + + req.VersionInfo = res.VersionInfo + req.ResponseNonce = res.Nonce + require.NoError(t, stream.Send(req)) + + // It's hard to test for the _absence_ of a response to the ACK, however the followup check for the + // removed resource will fail if the server responds to the ACK with anything unexpected. The test + // can still be forced to fail early by checking the value of the ACK metric. 
+ require.Eventually(t, func() bool { + return statsHandler.ACKsReceived.Load() == 1 + }, 2*time.Second, 100*time.Millisecond) + + listenerCache.Clear(testListener2.Name, time.Now()) + waitForResponse(t, res, stream, 10*time.Millisecond) + require.Len(t, res.Resources, 1) + testutils.ProtoEquals(t, testutils.MustMarshal(t, testListener1).Resource, res.Resources[0]) + + // However, the server is supposed to respect any changes to subscriptions when ACKing, so when the + // most recent response is ACKed with a different subscription list, namely by adding a resource that does not + // exist (testListener3 in this case), the server should send back a response indicating that it does + // not exist (which in SotW means sending only testListener1). + req.VersionInfo = res.VersionInfo + req.ResponseNonce = res.Nonce + req.ResourceNames = append(req.ResourceNames, "testListener3") + require.NoError(t, stream.Send(req)) + + waitForResponse(t, res, stream, 10*time.Millisecond) + require.Len(t, res.Resources, 1) + testutils.ProtoEquals(t, testutils.MustMarshal(t, testListener1).Resource, res.Resources[0]) + require.Equal(t, int64(2), statsHandler.ACKsReceived.Load()) + + // Finally, check the NACK logic. Note that even though the subscription list hasn't changed, no + // response is expected. + req.VersionInfo = res.VersionInfo + req.ResponseNonce = res.Nonce + req.ErrorDetail = &status.Status{ + Code: 420, + Message: "Testing NACK", + } + require.NoError(t, stream.Send(req)) + + require.Eventually(t, func() bool { + return statsHandler.NACKsReceived.Load() == 1 + }, 2*time.Second, 100*time.Millisecond) + + require.NoError(t, stream.Send(&ads.SotWDiscoveryRequest{ + Node: new(core.Node), + TypeUrl: badTypeURL, + ResourceNames: badResources, + })) + + waitForResponse(t, res, stream, 10*time.Millisecond) + + require.Equal(t, badTypeURL, res.GetTypeUrl()) + require.Empty(t, res.GetResources()) + require.Equal(t, int64(1), statsHandler.UnknownTypes.Load()) + }) + + // Author's note: there are no semantic differences in the way subscriptions and ACKs are managed for + // pseudo delta SotW, so this test avoids retesting what was already tested in the previous test for + // brevity. + t.Run("PseudoDeltaSotW", func(t *testing.T) { + statsHandler.reset() + + // This test relies on Bytes being a pseudo delta resource type, so fail the test early otherwise + require.True(t, utils.IsPseudoDeltaSotW(testResource.TypeURL())) + + stream, err := client.StreamAggregatedResources(testutils.Context(t)) + require.NoError(t, err) + + req := &ads.SotWDiscoveryRequest{ + Node: locator.node, + TypeUrl: testResource.TypeURL(), + ResourceNames: []string{testResource.Name}, + } + require.NoError(t, stream.Send(req)) + + // PseudoDeltaSotW does not provide a mechanism for the server to communicate that a resource does not exist. + // Functionally, PseudoDeltaSotW clients are expected to treat requested resources that don't arrive within a + // given timeout as deleted. Here we check that the server does not respond with the requested resource + // before it is set: the test waits before setting the resource, then checks the time at which the response was received.
+ const wait = time.Second + time.AfterFunc(wait, func() { + // The metric should have been updated at this point, if not, fail the test + require.Equal(t, int64(1), statsHandler.UnknownResources.Load()) + setEntry(t) + }) + const delta = 50 * time.Millisecond + + startWait := time.Now() + res := new(ads.SotWDiscoveryResponse) + waitForResponse(t, res, stream, wait+delta) + require.WithinDuration(t, time.Now(), startWait.Add(wait), delta) + require.Len(t, res.Resources, 1) + testutils.ProtoEquals(t, testutils.MustMarshal(t, testResource).Resource, res.Resources[0]) + + // Note that the following is technically a protocol violation. As noted above, PseudoDeltaSotW + // cannot signal to the client that a resource does not exist, it simply never responds. However, in + // the event that the server receives a request for a type it does not know (and never will since + // they are not dynamic and determined at startup), since there is no way to signal that this request + // will never be satisfied, the server will respond with an empty response. + require.NoError(t, stream.Send(&ads.SotWDiscoveryRequest{ + Node: new(core.Node), + TypeUrl: badTypeURL, + ResourceNames: badResources, + })) + waitForResponse(t, res, stream, wait+delta) + require.Equal(t, badTypeURL, res.GetTypeUrl(), prototext.Format(res)) + require.Empty(t, res.GetResources()) + require.Equal(t, int64(1), statsHandler.UnknownTypes.Load()) + }) + +} + +type xDSResponse interface { + proto.Message + GetControlPlane() *core.ControlPlane +} + +// waitForResponse waits for a response on the given stream, failing the test if the response does +// not arrive within the timeout or if an error is returned. +func waitForResponse( + t *testing.T, + res xDSResponse, + stream interface{ RecvMsg(any) error }, + timeout time.Duration, +) { + t.Helper() + + ch := make(chan error) + go func() { + ch <- stream.RecvMsg(res) + }() + + select { + case err := <-ch: + require.NoError(t, err) + case <-time.After(timeout): + t.Fatalf("Did not receive response in %s", timeout) + } + testutils.ProtoEquals(t, controlPlane, res.GetControlPlane()) +} + +func readResourcesFromJSONFile(t *testing.T, f string) (resources []*ads.RawResource) { + data, err := os.ReadFile(f) + require.NoError(t, err) + + var rawResources []json.RawMessage + require.NoError(t, json.Unmarshal(data, &rawResources)) + for _, raw := range rawResources { + r := new(ads.RawResource) + require.NoError(t, protojson.Unmarshal(raw, r), string(raw)) + resources = append(resources, r) + } + return resources +} + +type simpleBatchHandler struct { + t *testing.T + notify func(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata) + ch atomic.Pointer[chan struct{}] +} + +func (h *simpleBatchHandler) StartNotificationBatch() { + ch := make(chan struct{}, 1) + require.True(h.t, h.ch.CompareAndSwap(nil, &ch)) +} + +func (h *simpleBatchHandler) Notify(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata) { + h.notify(name, r, metadata) +} + +func (h *simpleBatchHandler) ResourceMarshalError(name string, resource proto.Message, err error) { + h.t.Fatalf("Unexpected resource marshal error for %q: %v\n%v", name, err, resource) +} + +func (h *simpleBatchHandler) EndNotificationBatch() { + close(*h.ch.Load()) +} + +func (h *simpleBatchHandler) check() { + <-*h.ch.Swap(nil) +} + +func newSotWReq(subscribe ...string) *ads.SotWDiscoveryRequest { + return &ads.SotWDiscoveryRequest{ + ResourceNames: subscribe, + } +} + +func newDeltaReq(subscribe, unsubscribe []string) 
*ads.DeltaDiscoveryRequest { + return &ads.DeltaDiscoveryRequest{ + ResourceNamesSubscribe: subscribe, + ResourceNamesUnsubscribe: unsubscribe, + } +} + +func TestSubscriptionManagerSubscriptions(t *testing.T) { + const ( + r1 = "r1" + r2 = "r2" + ) + checkSubs := func(t *testing.T, c RawCache, h ads.RawSubscriptionHandler, wildcard, r1Sub, r2Sub bool) { + t.Helper() + require.Equal(t, wildcard, IsSubscribedTo(c, ads.WildcardSubscription, h), "wildcard") + require.Equal(t, r1Sub, IsSubscribedTo(c, r1, h), r1) + require.Equal(t, r2Sub, IsSubscribedTo(c, r2, h), r2) + } + + newCacheAndHandler := func(t *testing.T) (Cache[*wrapperspb.BoolValue], ResourceLocator, *simpleBatchHandler) { + tl := newTestLocator(t, nil, TypeOf[*wrapperspb.BoolValue]()) + c := getCache[*wrapperspb.BoolValue](tl) + expected := ads.NewResource(r1, "0", wrapperspb.Bool(true)) + c.SetResource(expected, time.Time{}) + + h := &simpleBatchHandler{ + t: t, + notify: func(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata) { + if name == r1 { + require.Same(t, testutils.MustMarshal(t, expected), r) + b, err := r.Resource.UnmarshalNew() + require.NoError(t, err) + testutils.ProtoEquals(t, wrapperspb.Bool(true), b) + } else { + require.Nil(t, r) + } + }, + } + + return c, tl, h + } + + for _, streamType := range []ads.StreamType{ads.DeltaStreamType, ads.SotWStreamType} { + t.Run(streamType.String(), func(t *testing.T) { + t.Run("wildcard", func(t *testing.T) { + c, l, h := newCacheAndHandler(t) + + var sotw internal.SubscriptionManager[*ads.SotWDiscoveryRequest] + var delta internal.SubscriptionManager[*ads.DeltaDiscoveryRequest] + if streamType == ads.DeltaStreamType { + delta = internal.NewDeltaSubscriptionManager(testutils.Context(t), l, c.Type().URL(), h) + } else { + sotw = internal.NewSotWSubscriptionManager(testutils.Context(t), l, c.Type().URL(), h) + } + + checkSubs(t, c, h, false, false, false) + + // subscribe to the wildcard + if streamType == ads.DeltaStreamType { + delta.ProcessSubscriptions(newDeltaReq([]string{ads.WildcardSubscription}, nil)) + } else { + sotw.ProcessSubscriptions(newSotWReq(ads.WildcardSubscription)) + } + h.check() + checkSubs(t, c, h, + true, + // implicit subscription to r1 via wildcard + true, + // implicit subscription to r2 via wildcard + true, + ) + + // subscribe to r2, unsubscribe from wildcard + if streamType == ads.DeltaStreamType { + delta.ProcessSubscriptions(newDeltaReq([]string{r2}, []string{ads.WildcardSubscription})) + } else { + sotw.ProcessSubscriptions(newSotWReq(r2)) + } + h.check() + checkSubs(t, c, h, + false, + // because r1 was not explicitly subscribed to, its implicit subscription should also be removed + false, + // explicit subscription + true, + ) + }) + + t.Run("normal", func(t *testing.T) { + c, l, h := newCacheAndHandler(t) + + var sotw internal.SubscriptionManager[*ads.SotWDiscoveryRequest] + var delta internal.SubscriptionManager[*ads.DeltaDiscoveryRequest] + if streamType == ads.DeltaStreamType { + delta = internal.NewDeltaSubscriptionManager(testutils.Context(t), l, c.Type().URL(), h) + } else { + sotw = internal.NewSotWSubscriptionManager(testutils.Context(t), l, c.Type().URL(), h) + } + + // subscribe to r1 and r2 + if streamType == ads.DeltaStreamType { + delta.ProcessSubscriptions(newDeltaReq([]string{r1, r2}, nil)) + } else { + sotw.ProcessSubscriptions(newSotWReq(r1, r2)) + } + h.check() + checkSubs(t, c, h, + false, + true, + true, + ) + + // unsubscribe from r2, keep r1 + if streamType == ads.DeltaStreamType { + 
delta.ProcessSubscriptions(newDeltaReq(nil, []string{r2}))
+ } else {
+ sotw.ProcessSubscriptions(newSotWReq(r1))
+ }
+ h.check()
+ checkSubs(t, c, h,
+ false,
+ true,
+ // unsubscribed
+ false,
+ )
+ })
+ })
+ }
+}
+
+type mockResourceLocator struct {
+ isTypeSupported func(typeURL string) bool
+ subscribe func(typeURL, resourceName string) func()
+ resubscribe func(typeURL, resourceName string)
+}
+
+func (m *mockResourceLocator) IsTypeSupported(_ context.Context, typeURL string) bool {
+ return m.isTypeSupported(typeURL)
+}
+
+func (m *mockResourceLocator) Subscribe(_ context.Context, typeURL, resourceName string, _ ads.RawSubscriptionHandler) func() {
+ return m.subscribe(typeURL, resourceName)
+}
+
+func (m *mockResourceLocator) Resubscribe(_ context.Context, typeURL, resourceName string, _ ads.RawSubscriptionHandler) {
+ m.resubscribe(typeURL, resourceName)
+}
+
+func TestImplicitWildcardSubscription(t *testing.T) {
+ const foo = "foo"
+ h := NewNoopBatchSubscriptionHandler(t)
+ typeURL := TypeOf[*ads.Secret]().URL()
+
+ newMockLocator := func(t *testing.T) (l *mockResourceLocator, wildcardSub, fooSub chan struct{}) {
+ wildcardSub = make(chan struct{}, 1)
+ fooSub = make(chan struct{}, 1)
+ l = &mockResourceLocator{
+ isTypeSupported: func(actualTypeURL string) bool {
+ require.Equal(t, typeURL, actualTypeURL)
+ return true
+ },
+ subscribe: func(actualTypeURL, resourceName string) func() {
+ require.Equal(t, typeURL, actualTypeURL)
+ switch resourceName {
+ case ads.WildcardSubscription:
+ wildcardSub <- struct{}{}
+ return func() {
+ close(wildcardSub)
+ }
+ case foo:
+ fooSub <- struct{}{}
+ return func() {
+ close(fooSub)
+ }
+ default:
+ t.Fatalf("Unexpected resource name %q", resourceName)
+ return nil
+ }
+ },
+ resubscribe: func(actualTypeURL, resourceName string) {
+ switch resourceName {
+ case ads.WildcardSubscription:
+ wildcardSub <- struct{}{}
+ case foo:
+ fooSub <- struct{}{}
+ default:
+ t.Fatalf("Unexpected resource name %q", resourceName)
+ }
+ },
+ }
+ return l, wildcardSub, fooSub
+ }
+ requireSelect := func(t *testing.T, ch <-chan struct{}, shouldBeClosed bool) {
+ t.Helper()
+ select {
+ case _, ok := <-ch:
+ if ok && shouldBeClosed {
+ t.Fatalf("Channel not closed")
+ }
+ if !ok && !shouldBeClosed {
+ t.Fatalf("Channel unexpectedly closed")
+ }
+ default:
+ t.Fatalf("empty channel!")
+ }
+ }
+
+ t.Run("SotW", func(t *testing.T) {
+ t.Run("empty first call", func(t *testing.T) {
+ l, wildcardSub, fooSub := newMockLocator(t)
+ m := internal.NewSotWSubscriptionManager(testutils.Context(t), l, typeURL, h)
+
+ // The first call, if empty, should always implicitly create a wildcard subscription.
+ m.ProcessSubscriptions(newSotWReq())
+ requireSelect(t, wildcardSub, false)
+
+ // Subsequent requests can ACK the previous response without changing the subscriptions or
+ // providing an explicit resource to subscribe to; in that case, the wildcard should persist.
+ m.ProcessSubscriptions(newSotWReq())
+ require.Empty(t, wildcardSub)
+
+ // However, once a resource name is explicitly provided, the implicit wildcard should disappear.
+ m.ProcessSubscriptions(newSotWReq(foo))
+ requireSelect(t, wildcardSub, true)
+ requireSelect(t, fooSub, false)
+ })
+ t.Run("non-empty first call", func(t *testing.T) {
+ l, wildcardSub, fooSub := newMockLocator(t)
+ m := internal.NewSotWSubscriptionManager(testutils.Context(t), l, typeURL, h)
+
+ // If the first call isn't empty, the implicit wildcard subscription should not be present.
+ m.ProcessSubscriptions(newSotWReq(foo))
+ requireSelect(t, fooSub, false)
+ require.Empty(t, wildcardSub)
+ })
+ t.Run("explicit wildcard", func(t *testing.T) {
+ l, wildcardSub, fooSub := newMockLocator(t)
+ m := internal.NewSotWSubscriptionManager(testutils.Context(t), l, typeURL, h)
+
+ m.ProcessSubscriptions(newSotWReq(ads.WildcardSubscription))
+ requireSelect(t, wildcardSub, false)
+ require.Empty(t, fooSub)
+
+ m.ProcessSubscriptions(newSotWReq())
+ requireSelect(t, wildcardSub, true)
+ })
+ })
+ t.Run("Delta", func(t *testing.T) {
+ t.Run("empty first call", func(t *testing.T) {
+ l, wildcardSub, fooSub := newMockLocator(t)
+ m := internal.NewDeltaSubscriptionManager(testutils.Context(t), l, typeURL, h)
+
+ // The first call, if empty, should always implicitly create a wildcard subscription.
+ m.ProcessSubscriptions(newDeltaReq(nil, nil))
+ requireSelect(t, wildcardSub, false)
+
+ // Subsequent requests can ACK the previous response without changing the subscriptions or
+ // providing an explicit resource to subscribe to; in that case, the wildcard should persist.
+ m.ProcessSubscriptions(newDeltaReq(nil, nil))
+ // However, unlike SotW, it should not resubscribe, because the subscription was not explicit.
+ require.Empty(t, wildcardSub)
+
+ // In Delta, the implicit wildcard subscription created by the first message must be explicitly
+ // removed.
+ m.ProcessSubscriptions(newDeltaReq([]string{foo}, nil))
+ // Since there was no explicit change to the wildcard subscription, no notification is expected.
+ require.Empty(t, wildcardSub)
+ requireSelect(t, fooSub, false)
+
+ m.ProcessSubscriptions(newDeltaReq(nil, []string{ads.WildcardSubscription}))
+ require.Empty(t, fooSub)
+ requireSelect(t, wildcardSub, true)
+ })
+ t.Run("non-empty first call", func(t *testing.T) {
+ l, wildcardSub, fooSub := newMockLocator(t)
+ m := internal.NewDeltaSubscriptionManager(testutils.Context(t), l, typeURL, h)
+
+ // If the first call isn't empty, the implicit wildcard subscription should not be present.
+ m.ProcessSubscriptions(newDeltaReq([]string{foo}, nil))
+ require.Empty(t, wildcardSub)
+ requireSelect(t, fooSub, false)
+ })
+ t.Run("explicit wildcard", func(t *testing.T) {
+ l, wildcardSub, fooSub := newMockLocator(t)
+ m := internal.NewDeltaSubscriptionManager(testutils.Context(t), l, typeURL, h)
+
+ m.ProcessSubscriptions(newDeltaReq([]string{ads.WildcardSubscription}, nil))
+ requireSelect(t, wildcardSub, false)
+ require.Empty(t, fooSub)
+
+ m.ProcessSubscriptions(newDeltaReq(nil, []string{ads.WildcardSubscription}))
+ requireSelect(t, wildcardSub, true)
+ require.Empty(t, fooSub)
+ })
+ })
+}
+
+// batchFuncHandler is the equivalent of funcHandler, but for the BatchSubscriptionHandler interface.
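+//
+// For illustration, a no-op handler equivalent to NewNoopBatchSubscriptionHandler (defined below)
+// can be constructed like so:
+//
+//	h := NewBatchSubscriptionHandler(
+//		t,
+//		func() {},
+//		func(string, *ads.RawResource, ads.SubscriptionMetadata) {},
+//		func() {},
+//	)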
+type batchFuncHandler struct { + t *testing.T + start func() + notify func(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata) + end func() +} + +func (b *batchFuncHandler) StartNotificationBatch() { + b.start() +} + +func (b *batchFuncHandler) Notify(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata) { + b.notify(name, r, metadata) +} + +func (b *batchFuncHandler) ResourceMarshalError(name string, resource proto.Message, err error) { + b.t.Fatalf("Unexpected resource marshal error for %q: %v\n%v", name, err, resource) +} + +func (b *batchFuncHandler) EndNotificationBatch() { + b.end() +} + +func NewBatchSubscriptionHandler( + t *testing.T, + start func(), + notify func(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata), + end func(), +) internal.BatchSubscriptionHandler { + return &batchFuncHandler{ + t: t, + start: start, + notify: notify, + end: end, + } +} + +func NewNoopBatchSubscriptionHandler(t *testing.T) internal.BatchSubscriptionHandler { + return NewBatchSubscriptionHandler( + t, + func() {}, func(string, *ads.RawResource, ads.SubscriptionMetadata) {}, func() {}, + ) +} + +func TestSubscriptionManagerUnsubscribeAll(t *testing.T) { + typeURL := TypeOf[*ads.Secret]().URL() + h := NewNoopBatchSubscriptionHandler(t) + + t.Run("explicit", func(t *testing.T) { + const foo = "foo" + + var wg sync.WaitGroup + + l := &mockResourceLocator{ + isTypeSupported: func(string) bool { return true }, + subscribe: func(_, resourceName string) func() { + wg.Done() + return func() { + wg.Done() + } + }, + } + + m := internal.NewDeltaSubscriptionManager(context.Background(), l, typeURL, h) + + wg.Add(2) + m.ProcessSubscriptions(&ads.DeltaDiscoveryRequest{ + ResourceNamesSubscribe: []string{ads.WildcardSubscription, foo}, + }) + wg.Wait() + + wg.Add(2) + m.UnsubscribeAll() + wg.Wait() + }) + + t.Run("on context expiry", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + l := &mockResourceLocator{ + isTypeSupported: func(string) bool { return true }, + subscribe: func(_, _ string) func() { + wg.Done() + return func() { + wg.Done() + } + }, + } + m := internal.NewDeltaSubscriptionManager(ctx, l, typeURL, h) + + wg.Add(1) + m.ProcessSubscriptions(&ads.DeltaDiscoveryRequest{ + ResourceNamesSubscribe: []string{ads.WildcardSubscription}, + }) + wg.Wait() + + wg.Add(1) + cancel() + wg.Wait() + }) + +} diff --git a/stats/server/server_stats.go b/stats/server/server_stats.go new file mode 100644 index 0000000..751d0d5 --- /dev/null +++ b/stats/server/server_stats.go @@ -0,0 +1,99 @@ +package serverstats + +import ( + "context" + "time" + + "github.com/linkedin/diderot/ads" + "google.golang.org/protobuf/proto" +) + +// Handler will be invoked with an event of the corresponding type when said event occurs. +type Handler interface { + HandleServerEvent(context.Context, Event) +} + +// Event contains information about a specific event that happened in the server. +type Event interface { + isServerEvent() +} + +// RequestReceived contains the stats of a request received by the server. +type RequestReceived struct { + // The received request, either [ads.SotWDiscoveryRequest] or [ads.DeltaDiscoveryRequest]. + Req proto.Message + // True if the client requested a type that is not supported, determined by the ResourceLocator. + IsRequestedTypeUnknown bool + // Whether the request is an ACK + IsACK bool + // Whether the request is a NACK. 
Note that NACKs are an important signal: they usually indicate a
+ // problem that requires immediate human intervention.
+ IsNACK bool
+ // How long it took to handle the request, i.e. validating it and processing its subscriptions if
+ // necessary. This does not include the time taken to send any resources in response.
+ Duration time.Duration
+}
+
+func (s *RequestReceived) isServerEvent() {}
+
+// ResponseSent contains the stats of a response sent by the server.
+type ResponseSent struct {
+ // The response sent, either [ads.SotWDiscoveryResponse] or [ads.DeltaDiscoveryResponse].
+ Res proto.Message
+ // How long the Send operation took. This includes any time added by flow-control.
+ Duration time.Duration
+}
+
+func (s *ResponseSent) isServerEvent() {}
+
+// TimeInGlobalRateLimiter contains the stats of the time spent in the global rate limiter.
+type TimeInGlobalRateLimiter struct {
+ // How long the server waited for the global rate limiter to clear.
+ Duration time.Duration
+}
+
+func (s *TimeInGlobalRateLimiter) isServerEvent() {}
+
+// ResourceMarshalError contains the stats for a resource that could not be marshaled. This
+// should be extremely rare and requires immediate attention.
+type ResourceMarshalError struct {
+ // The name of the resource that could not be marshaled.
+ ResourceName string
+ // The resource that could not be marshaled.
+ Resource proto.Message
+ // The marshaling error.
+ Err error
+}
+
+func (s *ResourceMarshalError) isServerEvent() {}
+
+// ResourceOverMaxSize contains the stats for a critical error that signals a resource will
+// never be received by clients that are subscribed to it. It likely requires immediate human
+// intervention.
+type ResourceOverMaxSize struct {
+ // The resource that could not be sent.
+ Resource *ads.RawResource
+ // The encoded resource size.
+ ResourceSize int
+ // The maximum resource size (usually 4MB, gRPC's default max message size).
+ MaxResourceSize int
+}
+
+func (s *ResourceOverMaxSize) isServerEvent() {}
+
+// ResourceQueued contains the stats for a resource entering the send queue.
+type ResourceQueued struct {
+ // The name of the resource.
+ ResourceName string
+ // The resource itself, nil if the resource is being deleted.
+ Resource *ads.RawResource
+ // The metadata for the resource and subscription.
+ Metadata ads.SubscriptionMetadata
+ // Whether the resource exists. When Resource is nil, this distinguishes the deletion of a resource
+ // that previously existed from a subscription to a resource that never existed. The latter should
+ // be rare, and can be indicative of a client-side bug.
+ ResourceExists bool +} + +func (s *ResourceQueued) isServerEvent() {} diff --git a/test_xds_config.json b/test_xds_config.json new file mode 100644 index 0000000..85cfaf5 --- /dev/null +++ b/test_xds_config.json @@ -0,0 +1,72 @@ +[ + { + "name": "testADSServer", + "version": "1", + "resource": { + "@type": "type.googleapis.com/envoy.config.listener.v3.Listener", + "name": "testADSServer", + "apiListener": { + "apiListener": { + "@type": "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager", + "rds": { + "configSource": { + "ads": {}, + "resourceApiVersion": "V3" + }, + "routeConfigName": "testADSServer" + }, + "httpFilters": [ + { + "name": "default", + "typedConfig": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.router.v3.Router" + } + } + ] + } + } + } + }, + { + "name": "testADSServer", + "version": "1", + "resource": { + "@type": "type.googleapis.com/envoy.config.route.v3.RouteConfiguration", + "name": "testADSServer", + "virtualHosts": [ + { + "name": "testADSServer", + "domains": [ + "*" + ], + "routes": [ + { + "name": "default", + "match": { + "prefix": "" + }, + "route": { + "cluster": "testADSServer" + } + } + ] + } + ] + } + }, + { + "name": "testADSServer", + "version": "1", + "resource": { + "@type": "type.googleapis.com/envoy.config.cluster.v3.Cluster", + "name": "testADSServer", + "type": "EDS", + "edsClusterConfig": { + "edsConfig": { + "ads": {}, + "resourceApiVersion": "V3" + } + } + } + } +] \ No newline at end of file diff --git a/testutils/testutils.go b/testutils/testutils.go new file mode 100644 index 0000000..d3ee2c6 --- /dev/null +++ b/testutils/testutils.go @@ -0,0 +1,266 @@ +package testutils + +import ( + "context" + "maps" + "net" + "slices" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/linkedin/diderot/ads" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" +) + +func WithTimeout(t *testing.T, name string, timeout time.Duration, f func(t *testing.T)) { + t.Run(name, func(t *testing.T) { + t.Helper() + done := make(chan struct{}) + go func() { + f(t) + close(done) + }() + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case <-timer.C: + t.Fatalf("%q failed to complete in %s", t.Name(), timeout) + case <-done: + return + } + }) +} + +func Context(tb testing.TB) context.Context { + ctx, cancel := context.WithCancel(context.Background()) + tb.Cleanup(cancel) + return ctx +} + +func ContextWithTimeout(tb testing.TB, timeout time.Duration) context.Context { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + tb.Cleanup(cancel) + return ctx +} + +type Notification[T proto.Message] struct { + Name string + Resource *ads.Resource[T] + Metadata ads.SubscriptionMetadata +} + +type ChanSubscriptionHandler[T proto.Message] chan Notification[T] + +func (c ChanSubscriptionHandler[T]) Notify(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata) { + c <- Notification[T]{ + Name: name, + Resource: r, + Metadata: metadata, + } +} + +// This is the bare minimum required by the testify framework. *testing.T implements it, but this +// interface is used for testing the test framework. 
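+//
+// For example, testutils_test.go implements this interface with a mock whose FailNow panics with a
+// sentinel value, allowing it to assert that a helper failed the test:
+//
+//	func (t *testingTMock) FailNow() {
+//		panic(failNowInvoked)
+//	}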
+type testingT interface { + Logf(format string, args ...any) + Errorf(format string, args ...any) + FailNow() + Helper() + Fatalf(string, ...any) +} + +var _ testingT = (*testing.T)(nil) +var _ testingT = (*testing.B)(nil) + +type ExpectedNotification[T proto.Message] struct { + Name string + Resource *ads.Resource[T] +} + +func ExpectDelete[T proto.Message](name string) ExpectedNotification[T] { + return ExpectedNotification[T]{Name: name} +} + +func ExpectUpdate[T proto.Message](r *ads.Resource[T]) ExpectedNotification[T] { + return ExpectedNotification[T]{Name: r.Name, Resource: r} +} + +func (c ChanSubscriptionHandler[T]) WaitForDelete( + t testingT, + expectedName string, +) Notification[T] { + t.Helper() + return c.WaitForNotifications(t, ExpectDelete[T](expectedName))[0] +} + +func (c ChanSubscriptionHandler[T]) WaitForUpdate(t testingT, r *ads.Resource[T]) Notification[T] { + t.Helper() + return c.WaitForNotifications(t, ExpectUpdate(r))[0] +} + +func (c ChanSubscriptionHandler[T]) WaitForNotifications(t testingT, notifications ...ExpectedNotification[T]) (out []Notification[T]) { + t.Helper() + + expectedNotifications := make(map[string]int) + for i, n := range notifications { + expectedNotifications[n.Name] = i + } + + out = make([]Notification[T], len(notifications)) + + for range notifications { + var n Notification[T] + select { + case n = <-c: + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive expected notification for one of: %v", + slices.Collect(maps.Keys(expectedNotifications))) + } + + idx, ok := expectedNotifications[n.Name] + if !ok { + require.Fail(t, "Received unexpected notification", n.Name) + } + expected := notifications[idx] + out[idx] = n + delete(expectedNotifications, n.Name) + + if expected.Resource != nil { + require.NotNilf(t, n.Resource, "Expected update for %q, got deletion instead", expected.Name) + ResourceEquals(t, expected.Resource, n.Resource) + } else { + require.Nilf(t, n.Resource, "Expected delete for %q, got update instead", expected.Name) + } + } + + require.Empty(t, expectedNotifications) + + return out +} + +func ResourceEquals[T proto.Message](t testingT, expected, actual *ads.Resource[T]) { + t.Helper() + require.Equal(t, expected.Name, actual.Name) + require.Equal(t, expected.Version, actual.Version) + ProtoEquals(t, expected.Resource, actual.Resource) + ProtoEquals(t, expected.Ttl, actual.Ttl) + ProtoEquals(t, expected.CacheControl, actual.CacheControl) + ProtoEquals(t, expected.Metadata, actual.Metadata) +} + +func ProtoEquals(t testingT, expected, actual proto.Message) { + t.Helper() + if !proto.Equal(expected, actual) { + t.Fatalf( + "Messages not equal:\nexpected:%s\nactual :%s\n%s", + expected, actual, + cmp.Diff(prototext.Format(expected), prototext.Format(actual)), + ) + } +} + +// FuncHandler is a SubscriptionHandler implementation that simply invokes a function. Note that the usual pattern of +// having a literal func type implement the interface (e.g. http.HandlerFunc) does not work in this case because funcs +// are not hashable and therefore cannot be used as map keys, which is often how SubscriptionHandlers are used. 
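+//
+// A short sketch of the difference (the subscribers map is purely illustrative):
+//
+//	h := NewSubscriptionHandler(func(name string, r *ads.Resource[*wrapperspb.BoolValue], md ads.SubscriptionMetadata) {})
+//	subscribers := map[ads.SubscriptionHandler[*wrapperspb.BoolValue]]struct{}{}
+//	subscribers[h] = struct{}{} // valid: *FuncHandler is hashable, a func literal is not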
+type FuncHandler[T proto.Message] struct {
+ notify func(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata)
+}
+
+func (f *FuncHandler[T]) Notify(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata) {
+ f.notify(name, r, metadata)
+}
+
+// NewSubscriptionHandler returns a SubscriptionHandler that invokes the given function when
+// SubscriptionHandler.Notify is invoked.
+func NewSubscriptionHandler[T proto.Message](
+ notify func(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata),
+) *FuncHandler[T] {
+ return &FuncHandler[T]{
+ notify: notify,
+ }
+}
+
+type RawFuncHandler struct {
+ t testingT
+ notify func(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata)
+}
+
+func (r *RawFuncHandler) Notify(name string, raw *ads.RawResource, metadata ads.SubscriptionMetadata) {
+ r.notify(name, raw, metadata)
+}
+
+func (r *RawFuncHandler) ResourceMarshalError(name string, resource proto.Message, err error) {
+ r.t.Fatalf("Unexpected resource marshal error for %q: %v\n%v", name, err, resource)
+}
+
+// NewRawSubscriptionHandler returns a RawSubscriptionHandler that invokes the given function when
+// RawSubscriptionHandler.Notify is invoked.
+func NewRawSubscriptionHandler(
+ t testingT,
+ notify func(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata),
+) *RawFuncHandler {
+ return &RawFuncHandler{t: t, notify: notify}
+}
+
+// TestServer is instantiated with NewTestGRPCServer and serves to facilitate local testing against
+// gRPC service implementations.
+type TestServer struct {
+ t *testing.T
+ *grpc.Server
+ net.Listener
+}
+
+// Start starts the backing gRPC server in a goroutine. Must be invoked _after_ registering the services.
+func (ts *TestServer) Start() {
+ go func() {
+ require.NoError(ts.t, ts.Server.Serve(ts.Listener))
+ }()
+}
+
+// Dial creates a client connection to the test server using grpc.NewClient, with insecure transport
+// credentials prepended to the given options.
+func (ts *TestServer) Dial(opts ...grpc.DialOption) *grpc.ClientConn {
+ opts = append([]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, opts...)
+ conn, err := grpc.NewClient(ts.AddrString(), opts...)
+ require.NoError(ts.t, err)
+ return conn
+}
+
+func (ts *TestServer) AddrString() string {
+ return ts.Addr().String()
+}
+
+// NewTestGRPCServer is a utility function that spins up a TCP listener on a random local port along
+// with a grpc.Server. It cleans up any associated state via t.Cleanup.
Sample usage is +// as follows: +// +// ts := NewTestGRPCServer(t) +// discovery.RegisterAggregatedDiscoveryServiceServer(ts.Server, s) +// ts.Start() +// conn := ts.Dial() +func NewTestGRPCServer(t *testing.T, opts ...grpc.ServerOption) *TestServer { + ts := &TestServer{ + t: t, + Server: grpc.NewServer(opts...), + } + + var err error + ts.Listener, err = net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + t.Cleanup(func() { + ts.Server.Stop() + }) + + return ts +} + +func MustMarshal[T proto.Message](t testingT, r *ads.Resource[T]) *ads.RawResource { + marshaled, err := r.Marshal() + require.NoError(t, err) + return marshaled +} diff --git a/testutils/testutils_test.go b/testutils/testutils_test.go new file mode 100644 index 0000000..3bdb941 --- /dev/null +++ b/testutils/testutils_test.go @@ -0,0 +1,118 @@ +package testutils + +import ( + "testing" + "time" + + "github.com/linkedin/diderot/ads" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +var failNowInvoked = new(byte) + +type testingTMock testing.T + +func (t *testingTMock) Errorf(format string, args ...any) { + (*testing.T)(t).Logf(format, args...) +} + +func (t *testingTMock) Fatalf(format string, args ...any) { + (*testing.T)(t).Logf(format, args...) + t.FailNow() +} + +func (t *testingTMock) FailNow() { + panic(failNowInvoked) +} + +func (t *testingTMock) Helper() { +} + +func TestChanSubscriptionHandler_WaitForNotification(t *testing.T) { + const foo = "foo" + + expected := wrapperspb.Int64(42) + version := "0" + resource := &ads.Resource[*wrapperspb.Int64Value]{ + Name: foo, + Version: version, + Resource: expected, + } + + tests := []struct { + name string + shouldFail bool + test func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) + }{ + { + name: "receive different name", + shouldFail: true, + test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { + metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} + h.Notify("bar", nil, metadata) + h.WaitForDelete(mock, foo) + }, + }, + { + name: "expect delete", + shouldFail: false, + test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { + metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} + h.Notify(foo, nil, metadata) + h.WaitForDelete(mock, foo) + }, + }, + { + name: "expect delete, get update", + shouldFail: true, + test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { + metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} + h.Notify(foo, resource, metadata) + h.WaitForDelete(mock, foo) + }, + }, + { + name: "expect update", + shouldFail: false, + test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { + metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} + h.Notify(foo, resource, metadata) + h.WaitForUpdate(mock, resource) + }, + }, + { + name: "expect update, get delete", + shouldFail: true, + test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { + metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} + h.Notify(foo, nil, metadata) + h.WaitForUpdate(mock, resource) + }, + }, + { + name: "received different value", + shouldFail: true, + test: func(mock *testingTMock, h ChanSubscriptionHandler[*wrapperspb.Int64Value]) { + metadata := ads.SubscriptionMetadata{SubscribedAt: time.Now()} + h.Notify(foo, resource, metadata) + h.WaitForUpdate(mock, ads.NewResource[*wrapperspb.Int64Value](foo, version, 
wrapperspb.Int64(27))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + h := make(ChanSubscriptionHandler[*wrapperspb.Int64Value], 1) + mock := (*testingTMock)(t) + if test.shouldFail { + require.PanicsWithValuef(t, failNowInvoked, func() { + test.test(mock, h) + }, "did not panic!") + } else { + test.test(mock, h) + } + }) + } + +} diff --git a/type.go b/type.go new file mode 100644 index 0000000..c14f934 --- /dev/null +++ b/type.go @@ -0,0 +1,123 @@ +package diderot + +import ( + "github.com/linkedin/diderot/ads" + "github.com/linkedin/diderot/internal/utils" + "google.golang.org/protobuf/proto" +) + +// typeReference is the only implementation of the Type and, by extension, the TypeReference +// interface. It is not exposed publicly to ensure that all instances are generated through TypeOf, +// which uses reflection on the type parameter to determine the type URL. This is to avoid potential +// runtime complications due to invalid type URL strings. +type typeReference[T proto.Message] string + +// TypeReference is a superset of the Type interface which captures the actual runtime type. +type TypeReference[T proto.Message] interface { + Type +} + +// Type is a type reference for a type that can be cached. Only accessible through TypeOf. +type Type interface { + // URL returns the type URL for this Type. + URL() string + // TrimmedURL returns the type URL for this Type without the leading "types.googleapis.com/" prefix. + // This string is useful when constructing xdstp URLs. + TrimmedURL() string + // NewCache is the untyped equivalent of this package's NewCache. The returned RawCache still + // retains the runtime type information and can be safely cast to the corresponding Cache type. + NewCache() RawCache + // NewPrioritizedCache is the untyped equivalent of this package's NewPrioritizedCache. The returned + // RawCache instances can be safely cast to the corresponding Cache type. + NewPrioritizedCache(prioritySlots int) []RawCache + + isSubscribedTo(c RawCache, name string, handler ads.RawSubscriptionHandler) bool + subscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) + unsubscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) +} + +func (t typeReference[T]) URL() string { + return string(t) +} + +func (t typeReference[T]) TrimmedURL() string { + return utils.TrimTypeURL(t.URL()) +} + +func (t typeReference[T]) NewCache() RawCache { + return NewCache[T]() +} + +func (t typeReference[T]) NewPrioritizedCache(prioritySlots int) []RawCache { + caches := NewPrioritizedCache[T](prioritySlots) + out := make([]RawCache, len(caches)) + for i, c := range caches { + out[i] = c + } + return out +} + +type wrappedHandler[T proto.Message] struct { + ads.RawSubscriptionHandler +} + +func (w wrappedHandler[T]) Notify(name string, r *ads.Resource[T], metadata ads.SubscriptionMetadata) { + var raw *ads.RawResource + if r != nil { + var err error + raw, err = r.Marshal() + if err != nil { + w.RawSubscriptionHandler.ResourceMarshalError(name, r.Resource, err) + return + } + } + w.RawSubscriptionHandler.Notify(name, raw, metadata) +} + +// toGenericHandler wraps the given RawSubscriptionHandler into a typed SubscriptionHandler. 
Multiple
+// invocations of this method with the same RawSubscriptionHandler always return a semantically
+// equivalent value, which is why the exported helpers below can wrap the handler anew on every call,
+// without needing to explicitly store and reuse the returned SubscriptionHandler:
+//
+//	var c RawCache
+//	var rawHandler ads.RawSubscriptionHandler
+//	Subscribe(c, "foo", rawHandler)
+//	Unsubscribe(c, "foo", rawHandler)
+func (t typeReference[T]) toGenericHandler(raw ads.RawSubscriptionHandler) ads.SubscriptionHandler[T] {
+ return wrappedHandler[T]{raw}
+}
+
+func (t typeReference[T]) isSubscribedTo(c RawCache, name string, handler ads.RawSubscriptionHandler) bool {
+ return c.(Cache[T]).IsSubscribedTo(name, t.toGenericHandler(handler))
+}
+
+func (t typeReference[T]) subscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) {
+ c.(Cache[T]).Subscribe(name, t.toGenericHandler(handler))
+}
+
+func (t typeReference[T]) unsubscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) {
+ c.(Cache[T]).Unsubscribe(name, t.toGenericHandler(handler))
+}
+
+// TypeOf returns a TypeReference that corresponds to the type parameter.
+func TypeOf[T proto.Message]() TypeReference[T] {
+ return typeReference[T](utils.GetTypeURL[T]())
+}
+
+// IsSubscribedTo checks whether the given handler is subscribed to the given named resource by invoking
+// the underlying generic API [diderot.Cache.IsSubscribedTo].
+func IsSubscribedTo(c RawCache, name string, handler ads.RawSubscriptionHandler) bool {
+ return c.Type().isSubscribedTo(c, name, handler)
+}
+
+// Subscribe registers the handler as a subscriber of the given named resource by invoking the
+// underlying generic API [diderot.Cache.Subscribe].
+func Subscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) {
+ c.Type().subscribe(c, name, handler)
+}
+
+// Unsubscribe unregisters the handler as a subscriber of the given named resource by invoking the
+// underlying generic API [diderot.Cache.Unsubscribe].
+func Unsubscribe(c RawCache, name string, handler ads.RawSubscriptionHandler) {
+ c.Type().unsubscribe(c, name, handler)
+}
diff --git a/type_test.go b/type_test.go
new file mode 100644
index 0000000..df759d3
--- /dev/null
+++ b/type_test.go
@@ -0,0 +1,51 @@
+package diderot
+
+import (
+ "testing"
+ "time"
+
+ "github.com/linkedin/diderot/ads"
+ "github.com/linkedin/diderot/testutils"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/protobuf/types/known/wrapperspb"
+)
+
+func TestType(t *testing.T) {
+ tests := []struct {
+ Name string
+ UseRawSetter bool
+ }{
+ {
+ Name: "typed",
+ UseRawSetter: false,
+ },
+ {
+ Name: "raw",
+ UseRawSetter: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.Name, func(t *testing.T) {
+ c := NewCache[*wrapperspb.BoolValue]()
+
+ const foo = "foo"
+
+ r := &ads.Resource[*wrapperspb.BoolValue]{
+ Name: foo,
+ Version: "0",
+ Resource: wrapperspb.Bool(true),
+ }
+ if test.UseRawSetter {
+ require.NoError(t, c.SetRaw(testutils.MustMarshal(t, r), time.Time{}))
+ } else {
+ c.SetResource(r, time.Time{})
+ }
+
+ testutils.ResourceEquals(t, r, c.Get(foo))
+
+ c.Clear(foo, time.Time{})
+ require.Nil(t, c.Get(foo))
+ })
+ }
+}
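+
+// TestRawSubscriptionHelpers is a minimal sketch of the untyped helpers above: Subscribe,
+// IsSubscribedTo, and Unsubscribe dispatch to the cache's generic API through its Type. The handler
+// deliberately ignores all notifications, and the resource name is arbitrary.
+func TestRawSubscriptionHelpers(t *testing.T) {
+ c := NewCache[*wrapperspb.BoolValue]()
+
+ const foo = "foo"
+
+ // A raw handler that drops all notifications; it only serves as a subscription key.
+ h := testutils.NewRawSubscriptionHandler(
+ t,
+ func(name string, r *ads.RawResource, metadata ads.SubscriptionMetadata) {},
+ )
+
+ // Subscribing through the untyped API should be visible through the untyped IsSubscribedTo.
+ Subscribe(c, foo, h)
+ require.True(t, IsSubscribedTo(c, foo, h))
+
+ // Unsubscribing removes the subscription.
+ Unsubscribe(c, foo, h)
+ require.False(t, IsSubscribedTo(c, foo, h))
+}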