From 012f59d1b5fbe0cfd15f4fd09ffd8e23c51d0004 Mon Sep 17 00:00:00 2001
From: Jussi Maki
Date: Mon, 3 Jun 2024 17:34:53 +0200
Subject: [PATCH] k8s: Reflector for reflecting ListerWatcher into StateDB table

This implements a reflector from a client-go ListerWatcher into a StateDB
table. It is implemented on top of the client-go Reflector that the
Informer uses.

Example usage:

  var (
    db       *statedb.DB
    myTable  statedb.RWTable[*someObject]
    lw       cache.ListerWatcher // Same as with Resource[T]
    jobGroup job.Group
  )

  cfg := k8s.ReflectorConfig[*someObject]{
    Table:         myTable,
    ListerWatcher: lw,
  }

  // Register a background job to the 'jobGroup' to start reflecting
  // to [myTable] with the [lw].
  k8s.RegisterReflector[*someObject](jobGroup, db, cfg)

  // Initialized() returns true when the initial listing is done,
  // e.g. this is a replacement for the Resource[T] "Sync" event or
  // "CachesSynced".
  for !myTable.Initialized(db.ReadTxn()) { ... }

  // The objects can be queried:
  myTable.Get(db.ReadTxn(), myIndex.Query("foo"))

  // A Resource[T]-style event stream is possible with Changes():
  wtxn := db.WriteTxn(myTable)
  changes := myTable.Changes(wtxn, "foo")
  wtxn.Commit()
  for {
    for obj, rev, deleted, ok := changes.Next(); ok; obj, rev, deleted, ok = changes.Next() {
      // process change
    }
    // Wait for new changes to refresh the iterator.
    <-changes.Wait(db.ReadTxn())
  }

  // The above iteration API can also be turned into an event stream for an
  // almost drop-in replacement for Resource[T] (there's no retrying though):
  var src stream.Observable[statedb.Change[*someObject]] = statedb.Observable(db, myTable)

Signed-off-by: Jussi Maki
---
 pkg/k8s/statedb.go      | 308 ++++++++++++++++++++++++++++++++++++++++
 pkg/k8s/statedb_test.go | 149 +++++++++++++++++++
 2 files changed, 457 insertions(+)
 create mode 100644 pkg/k8s/statedb.go
 create mode 100644 pkg/k8s/statedb_test.go

diff --git a/pkg/k8s/statedb.go b/pkg/k8s/statedb.go
new file mode 100644
index 00000000000000..177f231bea0a1b
--- /dev/null
+++ b/pkg/k8s/statedb.go
@@ -0,0 +1,308 @@
+// SPDX-License-Identifier: Apache-2.0
+// Copyright Authors of Cilium
+
+package k8s
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"k8s.io/apimachinery/pkg/api/meta"
+	"k8s.io/client-go/tools/cache"
+
+	"github.com/cilium/hive/cell"
+	"github.com/cilium/hive/job"
+	"github.com/cilium/statedb"
+	"github.com/cilium/stream"
+)
+
+type ReflectorConfig[Obj any] struct {
+	// Maximum number of objects to commit in one transaction. Uses default if left zero.
+	BufferSize int
+
+	// The amount of time to wait for the buffer to fill. Uses default if left zero.
+	BufferWaitTime time.Duration
+
+	// The ListerWatcher to use to retrieve the objects
+	ListerWatcher cache.ListerWatcher
+
+	// Optional function to transform the objects given by the ListerWatcher
+	Transform TransformFunc[Obj]
+
+	// Optional function to query all objects. Used when replacing the objects on resync.
+	QueryAll QueryAllFunc[Obj]
+
+	// The table to reflect to.
+	Table statedb.RWTable[Obj]
+}
+
+// TransformFunc is an optional function to give to the Kubernetes reflector
+// to transform the object returned by the ListerWatcher to the desired
+// target object. If the function returns false the object is silently
+// skipped.
+type TransformFunc[Obj any] func(any) (obj Obj, ok bool)
+
+// QueryAllFunc is an optional function to give to the Kubernetes reflector
+// to query all objects in the table that are managed by the reflector.
+// It is used to delete all objects when the underlying cache.Reflector needs
+// to Replace() all items for a resync.
+type QueryAllFunc[Obj any] func(statedb.ReadTxn, statedb.Table[Obj]) statedb.Iterator[Obj]
+
+const (
+	// DefaultBufferSize is the maximum number of objects to commit to the table in one write transaction.
+	DefaultBufferSize = 64
+
+	// DefaultBufferWaitTime is the amount of time to wait to fill the buffer before committing objects.
+	DefaultBufferWaitTime = 50 * time.Millisecond
+)
+
+func (cfg ReflectorConfig[Obj]) getBufferSize() int {
+	if cfg.BufferSize == 0 {
+		return DefaultBufferSize
+	}
+	return cfg.BufferSize
+}
+
+func (cfg ReflectorConfig[Obj]) getWaitTime() time.Duration {
+	if cfg.BufferWaitTime == 0 {
+		return DefaultBufferWaitTime
+	}
+	return cfg.BufferWaitTime
+}
+
+// RegisterReflector registers a Kubernetes to StateDB table reflector.
+func RegisterReflector[Obj any](jobGroup job.Group, db *statedb.DB, cfg ReflectorConfig[Obj]) {
+	// Register initializer that marks when the table has been initially populated,
+	// e.g. the initial "List" has concluded.
+	r := &k8sReflector[Obj]{
+		ReflectorConfig: cfg,
+		db:              db,
+	}
+	wtxn := db.WriteTxn(cfg.Table)
+	r.initDone = cfg.Table.RegisterInitializer(wtxn, "k8s-reflector")
+	wtxn.Commit()
+
+	jobGroup.Add(job.OneShot(
+		fmt.Sprintf("k8s-reflector-[%T]", *new(Obj)),
+		r.run))
+}
+
+type k8sReflector[Obj any] struct {
+	ReflectorConfig[Obj]
+
+	initDone func(statedb.WriteTxn)
+	db       *statedb.DB
+}
+
+func (r *k8sReflector[Obj]) run(ctx context.Context, health cell.Health) error {
+	type entry struct {
+		deleted   bool
+		name      string
+		namespace string
+		obj       Obj
+	}
+	type buffer struct {
+		replaceItems []any
+		entries      map[string]entry
+	}
+	bufferSize := r.getBufferSize()
+	waitTime := r.getWaitTime()
+	table := r.Table
+
+	transform := r.Transform
+	if transform == nil {
+		// No provided transform function, use the identity function instead.
+		transform = TransformFunc[Obj](func(obj any) (Obj, bool) { return obj.(Obj), true })
+	}
+
+	queryAll := r.QueryAll
+	if queryAll == nil {
+		// No query function provided, use All()
+		queryAll = QueryAllFunc[Obj](func(txn statedb.ReadTxn, tbl statedb.Table[Obj]) statedb.Iterator[Obj] {
+			return tbl.All(txn)
+		})
+	}
+
+	// Construct a stream of K8s objects, buffered into chunks every [waitTime] period
+	// and then committed.
+	// This reduces the number of write transactions required and thus the number of times
+	// readers get woken up, which results in much better overall throughput.
+	src := stream.Buffer(
+		ListerWatcherToObservable(r.ListerWatcher),
+		bufferSize,
+		waitTime,
+
+		// Buffer the events into a map, coalescing them by key.
+		func(buf *buffer, ev CacheStoreEvent) *buffer {
+			switch {
+			case ev.Kind == CacheStoreEventReplace:
+				return &buffer{
+					replaceItems: ev.Obj.([]any),
+					entries:      make(map[string]entry, bufferSize), // Forget prior entries
+				}
+			case buf == nil:
+				buf = &buffer{
+					replaceItems: nil,
+					entries:      make(map[string]entry, bufferSize),
+				}
+			}
+
+			var entry entry
+			entry.deleted = ev.Kind == CacheStoreEventDelete
+
+			var key string
+			if d, ok := ev.Obj.(cache.DeletedFinalStateUnknown); ok {
+				key = d.Key
+				var err error
+				entry.namespace, entry.name, err = cache.SplitMetaNamespaceKey(d.Key)
+				if err != nil {
+					panic(fmt.Sprintf("%T internal error: cache.SplitMetaNamespaceKey(%q) failed: %s", r, d.Key, err))
+				}
+				entry.obj, ok = transform(d.Obj)
+				if !ok {
+					return buf
+				}
+			} else {
+				meta, err := meta.Accessor(ev.Obj)
+				if err != nil {
+					panic(fmt.Sprintf("%T internal error: meta.Accessor failed: %s", r, err))
+				}
+				entry.name = meta.GetName()
+				if ns := meta.GetNamespace(); ns != "" {
+					key = ns + "/" + meta.GetName()
+					entry.namespace = ns
+				} else {
+					key = meta.GetName()
+				}
+
+				var ok bool
+				entry.obj, ok = transform(ev.Obj)
+				if !ok {
+					return buf
+				}
+			}
+			buf.entries[key] = entry
+			return buf
+		},
+	)
+
+	commitBuffer := func(buf *buffer) {
+		txn := r.db.WriteTxn(r.Table)
+		defer txn.Commit()
+
+		if buf.replaceItems != nil {
+			iter := queryAll(txn, table)
+			for obj, _, ok := iter.Next(); ok; obj, _, ok = iter.Next() {
+				table.Delete(txn, obj)
+			}
+			for _, item := range buf.replaceItems {
+				if obj, ok := transform(item); ok {
+					table.Insert(txn, obj)
+				}
+			}
+			// Mark the table as initialized. Internally this has a sync.Once
+			// so safe to call multiple times.
+			r.initDone(txn)
+		}
+
+		for _, entry := range buf.entries {
+			if !entry.deleted {
+				table.Insert(txn, entry.obj)
+			} else {
+				table.Delete(txn, entry.obj)
+			}
+		}
+	}
+
+	errs := make(chan error)
+	src.Observe(
+		ctx,
+		commitBuffer,
+		func(err error) {
+			errs <- err
+			close(errs)
+		},
+	)
+	return <-errs
+}
+
+// ListerWatcherToObservable turns a ListerWatcher into an observable using the
+// client-go's Reflector.
+func ListerWatcherToObservable(lw cache.ListerWatcher) stream.Observable[CacheStoreEvent] {
+	return stream.FuncObservable[CacheStoreEvent](
+		func(ctx context.Context, next func(CacheStoreEvent), complete func(err error)) {
+			store := &cacheStoreListener{
+				onAdd: func(obj any) {
+					next(CacheStoreEvent{CacheStoreEventAdd, obj})
+				},
+				onUpdate:  func(obj any) { next(CacheStoreEvent{CacheStoreEventUpdate, obj}) },
+				onDelete:  func(obj any) { next(CacheStoreEvent{CacheStoreEventDelete, obj}) },
+				onReplace: func(objs []any) { next(CacheStoreEvent{CacheStoreEventReplace, objs}) },
+			}
+			reflector := cache.NewReflector(lw, nil, store, 0)
+			go func() {
+				reflector.Run(ctx.Done())
+				complete(nil)
+			}()
+		})
+}
+
+type CacheStoreEventKind int
+
+const (
+	CacheStoreEventAdd = CacheStoreEventKind(iota)
+	CacheStoreEventUpdate
+	CacheStoreEventDelete
+	CacheStoreEventReplace
+)
+
+type CacheStoreEvent struct {
+	Kind CacheStoreEventKind
+	Obj  any
+}
+
+// cacheStoreListener implements the methods used by the cache reflector and
+// calls the given handlers for added, updated and deleted objects.
+type cacheStoreListener struct {
+	onAdd, onUpdate, onDelete func(any)
+	onReplace                 func([]any)
+}
+
+func (s *cacheStoreListener) Add(obj interface{}) error {
+	s.onAdd(obj)
+	return nil
+}
+
+func (s *cacheStoreListener) Update(obj interface{}) error {
+	s.onUpdate(obj)
+	return nil
+}
+
+func (s *cacheStoreListener) Delete(obj interface{}) error {
+	s.onDelete(obj)
+	return nil
+}
+
+func (s *cacheStoreListener) Replace(items []interface{}, resourceVersion string) error {
+	if items == nil {
+		// Always emit a non-nil slice for replace.
+		items = []interface{}{}
+	}
+	s.onReplace(items)
+	return nil
+}
+
+// These methods are never called by cache.Reflector:
+
+func (*cacheStoreListener) Get(obj interface{}) (item interface{}, exists bool, err error) {
+	panic("unimplemented")
+}
+func (*cacheStoreListener) GetByKey(key string) (item interface{}, exists bool, err error) {
+	panic("unimplemented")
+}
+func (*cacheStoreListener) List() []interface{} { panic("unimplemented") }
+func (*cacheStoreListener) ListKeys() []string  { panic("unimplemented") }
+func (*cacheStoreListener) Resync() error       { panic("unimplemented") }
+
+var _ cache.Store = &cacheStoreListener{}
diff --git a/pkg/k8s/statedb_test.go b/pkg/k8s/statedb_test.go
new file mode 100644
index 00000000000000..ff647265dbf2ed
--- /dev/null
+++ b/pkg/k8s/statedb_test.go
@@ -0,0 +1,149 @@
+// SPDX-License-Identifier: Apache-2.0
+// Copyright Authors of Cilium
+
+package k8s_test
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/cilium/hive/cell"
+	"github.com/cilium/hive/hivetest"
+	"github.com/cilium/statedb"
+	"github.com/cilium/statedb/index"
+	"github.com/stretchr/testify/require"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+
+	"github.com/cilium/cilium/pkg/hive"
+	"github.com/cilium/cilium/pkg/k8s"
+	k8sClient "github.com/cilium/cilium/pkg/k8s/client"
+	"github.com/cilium/cilium/pkg/k8s/utils"
+)
+
+var (
+	testNodeNameIndex = statedb.Index[*corev1.Node, string]{
+		Name: "name",
+		FromObject: func(obj *corev1.Node) index.KeySet {
+			return index.NewKeySet(index.String(obj.Name))
+		},
+		FromKey: index.String,
+		Unique:  true,
+	}
+)
+
+func newTestNodeTable(db *statedb.DB) (statedb.RWTable[*corev1.Node], error) {
+	tbl, err := statedb.NewTable(
+		"test-nodes",
+		testNodeNameIndex,
+	)
+	if err != nil {
+		return nil, err
+	}
+	return tbl, db.RegisterTable(tbl)
+}
+
+func TestReflector(t *testing.T) {
+	var (
+		nodeName = "some-node"
+		node     = &corev1.Node{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:            nodeName,
+				ResourceVersion: "0",
+			},
+			Status: corev1.NodeStatus{
+				Phase: "init",
+			},
+		}
+		fakeClient, cs = k8sClient.NewFakeClientset()
+
+		db        *statedb.DB
+		nodeTable statedb.Table[*corev1.Node]
+	)
+
+	// Create the initial version of the node. Do this before anything
+	// starts watching the resources to avoid a race.
+	fakeClient.KubernetesFakeClientset.Tracker().Create(
+		corev1.SchemeGroupVersion.WithResource("nodes"),
+		node.DeepCopy(), "")
+
+	var testTimeout = 10 * time.Second
+	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
+	defer cancel()
+
+	hive := hive.New(
+		cell.Provide(func() k8sClient.Clientset { return cs }),
+		cell.Module("test", "test",
+			cell.ProvidePrivate(
+				func(client k8sClient.Clientset, tbl statedb.RWTable[*corev1.Node]) k8s.ReflectorConfig[*corev1.Node] {
+					return k8s.ReflectorConfig[*corev1.Node]{
+						BufferSize:     10,
+						BufferWaitTime: time.Millisecond,
+						ListerWatcher:  utils.ListerWatcherFromTyped(client.CoreV1().Nodes()),
+						Transform:      nil,
+						QueryAll:       nil,
+						Table:          tbl,
+					}
+				},
+				newTestNodeTable,
+			),
+			cell.Invoke(
+				k8s.RegisterReflector[*corev1.Node],
+				func(db_ *statedb.DB, nodeTable_ statedb.RWTable[*corev1.Node]) {
+					db = db_
+					nodeTable = nodeTable_
+				}),
+		),
+	)
+
+	tlog := hivetest.Logger(t)
+	if err := hive.Start(tlog, ctx); err != nil {
+		t.Fatalf("hive.Start failed: %s", err)
+	}
+
+	// Wait until the table has been initialized.
+	require.Eventually(
+		t,
+		func() bool { return nodeTable.Initialized(db.ReadTxn()) },
+		time.Second,
+		5*time.Millisecond)
+
+	iter, watch := nodeTable.AllWatch(db.ReadTxn())
+	nodes := statedb.Collect(iter)
+	require.Len(t, nodes, 1)
+	require.Equal(t, nodeName, nodes[0].Name)
+
+	// Update the node and check that it updated.
+	node.Status.Phase = "update1"
+	node.ObjectMeta.ResourceVersion = "1"
+	fakeClient.KubernetesFakeClientset.Tracker().Update(
+		corev1.SchemeGroupVersion.WithResource("nodes"),
+		node.DeepCopy(), "")
+
+	// Wait until updated.
+	<-watch
+
+	iter, watch = nodeTable.AllWatch(db.ReadTxn())
+	nodes = statedb.Collect(iter)
+
+	require.Len(t, nodes, 1)
+	require.EqualValues(t, "update1", nodes[0].Status.Phase)
+
+	// Finally delete the node
+	fakeClient.KubernetesFakeClientset.Tracker().Delete(
+		corev1.SchemeGroupVersion.WithResource("nodes"),
+		"", "some-node")
+
+	<-watch
+
+	iter, _ = nodeTable.AllWatch(db.ReadTxn())
+	nodes = statedb.Collect(iter)
+	require.Len(t, nodes, 0)
+
+	// Finally check that the hive stops correctly. Note that we're not doing this in a
+	// defer to avoid potentially deadlocking on the Fatal calls.
+	if err := hive.Stop(tlog, context.TODO()); err != nil {
+		t.Fatalf("hive.Stop failed: %s", err)
+	}
+}
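Reviewer note: below is a minimal sketch of how a consumer could subscribe to a reflected table as an event stream via the statedb.Observable helper mentioned in the commit message. It is not part of this patch: the watchNodes function and the *corev1.Node table are illustrative only (mirroring the test), and the Object/Revision/Deleted fields of statedb.Change are assumed from the statedb API rather than shown in the diff.

package example

import (
	"context"
	"fmt"

	corev1 "k8s.io/api/core/v1"

	"github.com/cilium/statedb"
)

// watchNodes (illustrative) turns a reflected node table into an event
// stream and logs every upsert and delete until the stream completes.
func watchNodes(ctx context.Context, db *statedb.DB, nodes statedb.Table[*corev1.Node]) error {
	// statedb.Observable wraps the Changes() iterator shown in the commit message.
	src := statedb.Observable(db, nodes)

	errs := make(chan error, 1)
	src.Observe(
		ctx,
		func(change statedb.Change[*corev1.Node]) {
			if change.Deleted {
				fmt.Printf("node %s deleted (rev %d)\n", change.Object.Name, change.Revision)
			} else {
				fmt.Printf("node %s upserted (rev %d)\n", change.Object.Name, change.Revision)
			}
		},
		// Completion callback, mirroring the pattern used in k8sReflector.run above.
		func(err error) { errs <- err },
	)
	return <-errs
}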