diff --git a/main.go b/main.go index dd6c544eb..5107b442c 100644 --- a/main.go +++ b/main.go @@ -132,7 +132,7 @@ func main() { return } - gw, err := gateway.New(&gateway.Config{ + gw, err := gateway.New(ctx, &gateway.Config{ LeaderElectionConfig: gateway.LeaderElectionConfig{ LeaseDuration: time.Duration(gwLeadershipConfig.LeaseDuration) * time.Second, RenewDeadline: time.Duration(gwLeadershipConfig.RenewDeadline) * time.Second, diff --git a/pkg/endpoint/local_endpoint.go b/pkg/endpoint/local_endpoint.go index 09f4ceec5..4d65663af 100644 --- a/pkg/endpoint/local_endpoint.go +++ b/pkg/endpoint/local_endpoint.go @@ -133,13 +133,13 @@ func (l *Local) Create(ctx context.Context) error { return err } -func GetLocalSpec(submSpec *types.SubmarinerSpecification, k8sClient kubernetes.Interface, +func GetLocalSpec(ctx context.Context, submSpec *types.SubmarinerSpecification, k8sClient kubernetes.Interface, airGappedDeployment bool, ) (*submv1.EndpointSpec, error) { // We'll panic if submSpec is nil, this is intentional privateIP := GetLocalIP() - gwNode, err := node.GetLocalNode(k8sClient) + gwNode, err := node.GetLocalNode(ctx, k8sClient) if err != nil { return nil, errors.Wrap(err, "getting information on the local node") } diff --git a/pkg/endpoint/local_endpoint_test.go b/pkg/endpoint/local_endpoint_test.go index dfcbe44f8..ba80d3c2d 100644 --- a/pkg/endpoint/local_endpoint_test.go +++ b/pkg/endpoint/local_endpoint_test.go @@ -96,7 +96,7 @@ var _ = Describe("GetLocalSpec", func() { }) It("should return a valid EndpointSpec object", func() { - spec, err := endpoint.GetLocalSpec(submSpec, client, false) + spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, false) Expect(err).ToNot(HaveOccurred()) Expect(spec.ClusterID).To(Equal("east")) @@ -116,7 +116,7 @@ var _ = Describe("GetLocalSpec", func() { }) It("should return the udp-port backend config of the cluster", func() { - spec, err := endpoint.GetLocalSpec(submSpec, client, false) + spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, false) Expect(err).ToNot(HaveOccurred()) Expect(spec.BackendConfig[testUDPPortLabel]).To(Equal(testClusterUDPPort)) }) @@ -128,14 +128,14 @@ var _ = Describe("GetLocalSpec", func() { }) It("should return a valid EndpointSpec object", func() { - _, err := endpoint.GetLocalSpec(submSpec, client, false) + _, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, false) Expect(err).ToNot(HaveOccurred()) }) }) When("the gateway node is not annotated with public-ip", func() { It("should use empty public-ip in the endpoint object for air-gapped deployments", func() { - spec, err := endpoint.GetLocalSpec(submSpec, client, true) + spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, true) Expect(err).ToNot(HaveOccurred()) Expect(spec.ClusterID).To(Equal("east")) @@ -150,7 +150,7 @@ var _ = Describe("GetLocalSpec", func() { }) It("should use the annotated public-ip for air-gapped deployments", func() { - spec, err := endpoint.GetLocalSpec(submSpec, client, true) + spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, true) Expect(err).ToNot(HaveOccurred()) Expect(spec.PrivateIP).To(Equal(testPrivateIP)) @@ -164,7 +164,7 @@ var _ = Describe("GetLocalSpec", func() { }) It("should set the HealthCheckIP", func() { - spec, err := endpoint.GetLocalSpec(submSpec, client, true) + spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, true) Expect(err).ToNot(HaveOccurred()) Expect(spec.HealthCheckIP).To(Equal(cniInterfaceIP)) }) @@ -175,7 +175,7 @@ var _ = Describe("GetLocalSpec", func() { }) It("should not set the HealthCheckIP", func() { - spec, err := endpoint.GetLocalSpec(submSpec, client, true) + spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, true) Expect(err).ToNot(HaveOccurred()) Expect(spec.HealthCheckIP).To(BeEmpty()) }) diff --git a/pkg/event/handler.go b/pkg/event/handler.go index 0386637a9..f794523af 100644 --- a/pkg/event/handler.go +++ b/pkg/event/handler.go @@ -19,6 +19,7 @@ limitations under the License. package event import ( + "context" "sync/atomic" submV1 "github.com/submariner-io/submariner/pkg/apis/submariner.io/v1" @@ -81,7 +82,7 @@ type NodeHandler interface { type Handler interface { // Init is called once on startup to let the handler initialize any state it needs. - Init() error + Init(ctx context.Context) error // SetHandlerState is called once on startup after Init with the HandlerState that can be used to access global data from event callbacks. SetState(handlerCtx HandlerState) @@ -106,7 +107,7 @@ type HandlerBase struct { handlerState atomic.Value } -func (ev *HandlerBase) Init() error { +func (ev *HandlerBase) Init(_ context.Context) error { return nil } diff --git a/pkg/event/registry.go b/pkg/event/registry.go index 771a54142..977ae68a2 100644 --- a/pkg/event/registry.go +++ b/pkg/event/registry.go @@ -19,6 +19,7 @@ limitations under the License. package event import ( + "context" "strings" "github.com/pkg/errors" @@ -39,7 +40,7 @@ var logger = log.Logger{Logger: logf.Log.WithName("EventRegistry")} // NewRegistry creates a new registry with the given name, typically referencing the owner, to manage event // Handlers that match the given networkPlugin name. The given event Handlers whose associated network plugin matches the given // networkPlugin name are added. Non-matching Handlers are ignored. Handlers will be called in registration order. -func NewRegistry(name, networkPlugin string, eventHandlers ...Handler) (*Registry, error) { +func NewRegistry(ctx context.Context, name, networkPlugin string, eventHandlers ...Handler) (*Registry, error) { r := &Registry{ name: name, networkPlugin: strings.ToLower(networkPlugin), @@ -47,7 +48,7 @@ func NewRegistry(name, networkPlugin string, eventHandlers ...Handler) (*Registr } for _, eventHandler := range eventHandlers { - err := r.addHandler(eventHandler) + err := r.addHandler(ctx, eventHandler) if err != nil { return nil, err } @@ -61,7 +62,7 @@ func (er *Registry) GetName() string { return er.name } -func (er *Registry) addHandler(eventHandler Handler) error { +func (er *Registry) addHandler(ctx context.Context, eventHandler Handler) error { evNetworkPlugins := set.New[string]() for _, np := range eventHandler.GetNetworkPlugins() { @@ -69,7 +70,7 @@ func (er *Registry) addHandler(eventHandler Handler) error { } if evNetworkPlugins.Has(AnyNetworkPlugin) || evNetworkPlugins.Has(er.networkPlugin) { - if err := eventHandler.Init(); err != nil { + if err := eventHandler.Init(ctx); err != nil { return errors.Wrapf(err, "Event handler %q failed to initialize", eventHandler.GetName()) } diff --git a/pkg/event/registry_test.go b/pkg/event/registry_test.go index 8ec307576..80cb29397 100644 --- a/pkg/event/registry_test.go +++ b/pkg/event/registry_test.go @@ -19,6 +19,8 @@ limitations under the License. package event_test import ( + "context" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" submv1 "github.com/submariner-io/submariner/pkg/apis/submariner.io/v1" @@ -55,8 +57,8 @@ var _ = Describe("Event Registry", func() { var err error - registry, err = event.NewRegistry("test-registry", npGenericKubeproxyIptables, logger.NewHandler(), matchingHandlers[0], - nonMatchingHandlers[0], matchingHandlers[1], matchingHandlers[2]) + registry, err = event.NewRegistry(context.TODO(), "test-registry", npGenericKubeproxyIptables, logger.NewHandler(), + matchingHandlers[0], nonMatchingHandlers[0], matchingHandlers[1], matchingHandlers[2]) Expect(err).NotTo(HaveOccurred()) Expect(registry.GetName()).To(Equal("test-registry")) }) @@ -115,7 +117,7 @@ var _ = Describe("Event Registry", func() { h := testing.NewTestHandler("test-handler", event.AnyNetworkPlugin, make(chan testing.TestEvent, 10)) h.FailOnEvent(testing.EvInit) - _, err := event.NewRegistry("test-registry", event.AnyNetworkPlugin, h) + _, err := event.NewRegistry(context.TODO(), "test-registry", event.AnyNetworkPlugin, h) Expect(err).To(HaveOccurred()) }) }) diff --git a/pkg/event/testing/controller_support.go b/pkg/event/testing/controller_support.go index 3760bc450..aa0671934 100644 --- a/pkg/event/testing/controller_support.go +++ b/pkg/event/testing/controller_support.go @@ -67,7 +67,7 @@ func (c *ControllerSupport) Start(handlers ...event.Handler) { networkPlugin = event.AnyNetworkPlugin } - registry, err := event.NewRegistry("test-registry", networkPlugin, handlers...) + registry, err := event.NewRegistry(context.TODO(), "test-registry", networkPlugin, handlers...) Expect(err).To(Succeed()) config := controller.Config{ diff --git a/pkg/event/testing/testing.go b/pkg/event/testing/testing.go index ca4eb9731..4927880cc 100644 --- a/pkg/event/testing/testing.go +++ b/pkg/event/testing/testing.go @@ -19,6 +19,7 @@ limitations under the License. package testing import ( + "context" "fmt" "sync" @@ -110,7 +111,7 @@ func (t *TestHandlerBase) addEvent(eventName string, param interface{}) error { return nil } -func (t *TestHandlerBase) Init() error { +func (t *TestHandlerBase) Init(_ context.Context) error { t.Initialized = true return t.checkFailOnEvent(EvInit) } diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 8ffc088bc..a63e391b0 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -104,7 +104,7 @@ type gatewayType struct { var logger = log.Logger{Logger: logf.Log.WithName("Gateway")} -func New(config *Config) (Interface, error) { +func New(ctx context.Context, config *Config) (Interface, error) { logger.Info("Initializing the gateway engine") g := &gatewayType{ @@ -144,7 +144,7 @@ func New(config *Config) (Interface, error) { g.airGapped = os.Getenv("AIR_GAPPED_DEPLOYMENT") == "true" logger.Infof("AIR_GAPPED_DEPLOYMENT is set to %t", g.airGapped) - localEndpointSpec, err := endpoint.GetLocalSpec(&g.Spec, g.KubeClient, g.airGapped) + localEndpointSpec, err := endpoint.GetLocalSpec(ctx, &g.Spec, g.KubeClient, g.airGapped) if err != nil { return nil, errors.Wrap(err, "error creating local endpoint object") } diff --git a/pkg/gateway/gateway_test.go b/pkg/gateway/gateway_test.go index bf88e2bde..7a032d430 100644 --- a/pkg/gateway/gateway_test.go +++ b/pkg/gateway/gateway_test.go @@ -321,7 +321,7 @@ func newTestDriver() *testDriver { }) JustBeforeEach(func() { - gw, err := gateway.New(&t.config) + gw, err := gateway.New(context.TODO(), &t.config) Expect(err).To(Succeed()) ctx, stop := context.WithCancel(context.Background()) diff --git a/pkg/globalnet/controllers/gateway_monitor.go b/pkg/globalnet/controllers/gateway_monitor.go index 88483a373..4b9120335 100644 --- a/pkg/globalnet/controllers/gateway_monitor.go +++ b/pkg/globalnet/controllers/gateway_monitor.go @@ -53,7 +53,7 @@ import ( "k8s.io/utils/set" ) -func NewGatewayMonitor(config *GatewayMonitorConfig) (Interface, error) { +func NewGatewayMonitor(ctx context.Context, config *GatewayMonitorConfig) (Interface, error) { // We'll panic if config is nil, this is intentional gatewayMonitor := &gatewayMonitor{ baseController: newBaseController(), @@ -130,7 +130,12 @@ func NewGatewayMonitor(config *GatewayMonitorConfig) (Interface, error) { gatewayMonitor.gatewaySharedInformerStopCh = make(chan struct{}) - return &gatewayMonitorInterface{monitor: gatewayMonitor}, nil + registry, err := event.NewRegistry(ctx, "globalnet-registry", event.AnyNetworkPlugin, gatewayMonitor) + if err != nil { + return nil, errors.Wrap(err, "error creating event registry") + } + + return &gatewayMonitorInterface{monitor: gatewayMonitor, registry: registry}, nil } func (g *gatewayMonitorInterface) Start() error { @@ -145,13 +150,8 @@ func (g *gatewayMonitorInterface) Start() error { g.monitor.gatewaySharedInformer.Run(g.monitor.gatewaySharedInformerStopCh) }() - registry, err := event.NewRegistry("globalnet-registry", event.AnyNetworkPlugin, g.monitor) - if err != nil { - return errors.Wrap(err, "error creating event registry") - } - eventController, err := controller.New(&controller.Config{ - Registry: registry, + Registry: g.registry, RestMapper: g.monitor.RestMapper, Client: g.monitor.Client, Scheme: g.monitor.Scheme, @@ -175,7 +175,7 @@ func (g *gatewayMonitor) GetNetworkPlugins() []string { return []string{event.AnyNetworkPlugin} } -func (g *gatewayMonitor) Init() error { +func (g *gatewayMonitor) Init(_ context.Context) error { return g.createNATChain(constants.SmGlobalnetMarkChain) } diff --git a/pkg/globalnet/controllers/gateway_monitor_test.go b/pkg/globalnet/controllers/gateway_monitor_test.go index ce60bd013..71eaa9a13 100644 --- a/pkg/globalnet/controllers/gateway_monitor_test.go +++ b/pkg/globalnet/controllers/gateway_monitor_test.go @@ -416,7 +416,7 @@ func (t *gatewayMonitorTestDriver) start() { localSubnets := []string{localCIDR} - t.controller, err = controllers.NewGatewayMonitor(&controllers.GatewayMonitorConfig{ + t.controller, err = controllers.NewGatewayMonitor(context.TODO(), &controllers.GatewayMonitorConfig{ RestMapper: t.restMapper, Client: t.dynClient, Scheme: t.scheme, diff --git a/pkg/globalnet/controllers/types.go b/pkg/globalnet/controllers/types.go index 13b94b9b1..c0df0c94d 100644 --- a/pkg/globalnet/controllers/types.go +++ b/pkg/globalnet/controllers/types.go @@ -120,7 +120,8 @@ type LeaderElectionInfo struct { } type gatewayMonitorInterface struct { - monitor *gatewayMonitor + monitor *gatewayMonitor + registry *event.Registry } type gatewayMonitor struct { diff --git a/pkg/globalnet/main.go b/pkg/globalnet/main.go index 9ff12ff10..cbd234dd4 100644 --- a/pkg/globalnet/main.go +++ b/pkg/globalnet/main.go @@ -101,7 +101,7 @@ func main() { logger.Info("Starting submariner-globalnet", spec) // set up signals so we handle the first shutdown signal gracefully - stopCh := signals.SetupSignalHandler().Done() + ctx := signals.SetupSignalHandler() defer http.StartServer(http.Metrics|http.Profile, spec.MetricsPort)() @@ -139,7 +139,7 @@ func main() { clusterCIDRs := cidr.ExtractIPv4Subnets(localCluster.Spec.ClusterCIDR) - gatewayMonitor, err := controllers.NewGatewayMonitor(&controllers.GatewayMonitorConfig{ + gatewayMonitor, err := controllers.NewGatewayMonitor(ctx, &controllers.GatewayMonitorConfig{ Client: dynClient, RestMapper: restMapper, Scheme: scheme.Scheme, @@ -154,7 +154,7 @@ func main() { err = gatewayMonitor.Start() logger.FatalOnError(err, "Error starting the gatewayMonitor") - <-stopCh + <-ctx.Done() gatewayMonitor.Stop() logger.Infof("All controllers stopped or exited. Stopping main loop") diff --git a/pkg/node/node.go b/pkg/node/node.go index a215ecef1..5f17e47f4 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -30,21 +30,19 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/util/retry" nodeutil "k8s.io/component-helpers/node/util" logf "sigs.k8s.io/controller-runtime/pkg/log" ) var logger = log.Logger{Logger: logf.Log.WithName("Node")} -var Retry = wait.Backoff{ - Steps: 5, - Duration: 5 * time.Second, - Factor: 1.2, - Jitter: 0.1, -} +// These are public to allow unit tests to override. +var ( + PollTimeout = time.Second * 30 + PollInterval = time.Second +) -func GetLocalNode(clientset kubernetes.Interface) (*v1.Node, error) { +func GetLocalNode(ctx context.Context, clientset kubernetes.Interface) (*v1.Node, error) { nodeName, ok := os.LookupEnv("NODE_NAME") if !ok { return nil, errors.New("error reading the NODE_NAME from the environment") @@ -52,27 +50,27 @@ func GetLocalNode(clientset kubernetes.Interface) (*v1.Node, error) { var node *v1.Node - err := retry.OnError(Retry, func(err error) bool { - logger.Warningf("Error reading the local node - retrying: %v", err) - return true - }, func() error { - var err error + err := wait.PollUntilContextTimeout(ctx, PollInterval, PollTimeout, true, + func(ctx context.Context) (bool, error) { + var err error - node, err = clientset.CoreV1().Nodes().Get(context.TODO(), nodeName, metav1.GetOptions{}) - if err != nil { - return errors.Wrapf(err, "unable to find local node %q", nodeName) - } + node, err = clientset.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{}) + if err == nil { + return true, nil + } - return nil - }) + logger.Warningf("Error retrieving the local node %q - retrying: %v", nodeName, err) + + return false, nil + }) return node, errors.Wrapf(err, "failed to get local node %q", nodeName) } func WaitForLocalNodeReady(ctx context.Context, client kubernetes.Interface) { // In most cases the node will already be ready; otherwise, wait forever or until the context is cancelled. - err := wait.PollUntilContextCancel(ctx, time.Second, true, func(_ context.Context) (bool, error) { - localNode, err := GetLocalNode(client) //nolint:contextcheck // TODO - should pass the context parameter + err := wait.PollUntilContextCancel(ctx, time.Second, true, func(ctx context.Context) (bool, error) { + localNode, err := GetLocalNode(ctx, client) if err != nil { logger.Error(err, "Error retrieving local node") diff --git a/pkg/node/node_test.go b/pkg/node/node_test.go index 6c15464ae..88c717b90 100644 --- a/pkg/node/node_test.go +++ b/pkg/node/node_test.go @@ -30,7 +30,6 @@ import ( corev1 "k8s.io/api/core/v1" v1meta "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/util/wait" fakeK8s "k8s.io/client-go/kubernetes/fake" nodeutil "k8s.io/component-helpers/node/util" ) @@ -42,7 +41,7 @@ var _ = Describe("GetLocalNode", func() { When("the local Node resource exists", func() { It("should return the resource", func() { - Expect(node.GetLocalNode(t.client)).To(Equal(t.node)) + Expect(node.GetLocalNode(context.TODO(), t.client)).To(Equal(t.node)) }) }) @@ -52,7 +51,7 @@ var _ = Describe("GetLocalNode", func() { }) It("should return an error", func() { - _, err := node.GetLocalNode(t.client) + _, err := node.GetLocalNode(context.TODO(), t.client) Expect(err).To(HaveOccurred()) }) }) @@ -63,7 +62,7 @@ var _ = Describe("GetLocalNode", func() { }) It("should eventually return the resource", func() { - Expect(node.GetLocalNode(t.client)).To(Equal(t.node)) + Expect(node.GetLocalNode(context.TODO(), t.client)).To(Equal(t.node)) }) }) @@ -73,7 +72,7 @@ var _ = Describe("GetLocalNode", func() { }) It("should return an error", func() { - _, err := node.GetLocalNode(t.client) + _, err := node.GetLocalNode(context.TODO(), t.client) Expect(err).To(HaveOccurred()) }) }) @@ -133,10 +132,8 @@ func newTestDriver() *testDriver { t := &testDriver{} BeforeEach(func() { - node.Retry = wait.Backoff{ - Steps: 2, - Duration: 10 * time.Millisecond, - } + node.PollTimeout = 30 * time.Millisecond + node.PollInterval = 10 * time.Millisecond t.node = &corev1.Node{ ObjectMeta: v1meta.ObjectMeta{ diff --git a/pkg/routeagent_driver/handlers/calico/ippool_handler.go b/pkg/routeagent_driver/handlers/calico/ippool_handler.go index 2e989728c..9b32aacb1 100644 --- a/pkg/routeagent_driver/handlers/calico/ippool_handler.go +++ b/pkg/routeagent_driver/handlers/calico/ippool_handler.go @@ -81,14 +81,14 @@ func (h *calicoIPPoolHandler) GetName() string { return "Calico IPPool handler" } -func (h *calicoIPPoolHandler) Init() error { +func (h *calicoIPPoolHandler) Init(ctx context.Context) error { var err error if h.client, err = NewClient(h.restConfig); err != nil { return errors.Wrap(err, "error initializing Calico clientset") } - return h.updateROKSCalicoCfg() + return h.updateROKSCalicoCfg(ctx) } func (h *calicoIPPoolHandler) RemoteEndpointCreated(endpoint *submV1.Endpoint) error { @@ -213,9 +213,9 @@ func getEndpointSubnetIPPoolName(endpoint *submV1.Endpoint, subnet string) strin return fmt.Sprintf("submariner-%s-%s", endpoint.Spec.ClusterID, strings.ReplaceAll(subnet, "/", "-")) } -func (h *calicoIPPoolHandler) platformIsROKS() (bool, error) { +func (h *calicoIPPoolHandler) platformIsROKS(ctx context.Context) (bool, error) { // Submariner GW is deployed on ROKS using LB service with specific annotations. - service, err := h.k8sClient.CoreV1().Services(h.namespace).Get(context.TODO(), GwLBSvcName, metav1.GetOptions{}) + service, err := h.k8sClient.CoreV1().Services(h.namespace).Get(ctx, GwLBSvcName, metav1.GetOptions{}) if apierrors.IsNotFound(err) { return false, nil } @@ -229,8 +229,8 @@ func (h *calicoIPPoolHandler) platformIsROKS() (bool, error) { // workaround to address datapath issue with default Calico IPPool configuration for ROKS platform, // IPIPMode of calico default IPPool should be set to 'Always'. -func (h *calicoIPPoolHandler) updateROKSCalicoCfg() error { - isROKS, err := h.platformIsROKS() +func (h *calicoIPPoolHandler) updateROKSCalicoCfg(ctx context.Context) error { + isROKS, err := h.platformIsROKS(ctx) if err != nil { return err } @@ -241,7 +241,7 @@ func (h *calicoIPPoolHandler) updateROKSCalicoCfg() error { // platform is ROKS, make sure that IPIPMode of default IPPool is Always - err = util.Update(context.TODO(), h.iPPoolResourceInterface(), &calicoapi.IPPool{ + err = util.Update(ctx, h.iPPoolResourceInterface(), &calicoapi.IPPool{ ObjectMeta: metav1.ObjectMeta{ Name: DefaultV4IPPoolName, }, diff --git a/pkg/routeagent_driver/handlers/healthchecker/healthchecker.go b/pkg/routeagent_driver/handlers/healthchecker/healthchecker.go index 8c9a6ac34..7a2cdd39a 100644 --- a/pkg/routeagent_driver/handlers/healthchecker/healthchecker.go +++ b/pkg/routeagent_driver/handlers/healthchecker/healthchecker.go @@ -170,7 +170,8 @@ func (h *controller) RemoteEndpointRemoved(endpoint *submarinerv1.Endpoint) erro return nil } -func (h *controller) Init() error { +func (h *controller) Init(_ context.Context) error { + //nolint:contextcheck // Ignore "should pass the context parameter" go func() { wait.Until(func() { h.Lock() diff --git a/pkg/routeagent_driver/handlers/kubeproxy/kp_packetfilter.go b/pkg/routeagent_driver/handlers/kubeproxy/kp_packetfilter.go index d79033a57..b6b29e5be 100644 --- a/pkg/routeagent_driver/handlers/kubeproxy/kp_packetfilter.go +++ b/pkg/routeagent_driver/handlers/kubeproxy/kp_packetfilter.go @@ -19,6 +19,7 @@ limitations under the License. package kubeproxy import ( + "context" "net" "os" "time" @@ -101,7 +102,7 @@ var discoverCNIRetryConfig = wait.Backoff{ Steps: 12, } -func (kp *SyncHandler) Init() error { +func (kp *SyncHandler) Init(_ context.Context) error { var err error var cniIface *cni.Interface diff --git a/pkg/routeagent_driver/handlers/mtu/mtuhandler.go b/pkg/routeagent_driver/handlers/mtu/mtuhandler.go index 772fb4eb5..53a067388 100644 --- a/pkg/routeagent_driver/handlers/mtu/mtuhandler.go +++ b/pkg/routeagent_driver/handlers/mtu/mtuhandler.go @@ -19,6 +19,7 @@ limitations under the License. package mtu import ( + "context" "strconv" "github.com/pkg/errors" @@ -81,7 +82,7 @@ func (h *mtuHandler) GetName() string { return "MTU handler" } -func (h *mtuHandler) Init() error { +func (h *mtuHandler) Init(_ context.Context) error { var err error h.pFilter, err = packetfilter.New() diff --git a/pkg/routeagent_driver/handlers/mtu/mtuhandler_test.go b/pkg/routeagent_driver/handlers/mtu/mtuhandler_test.go index 49ec694e0..ac3859086 100644 --- a/pkg/routeagent_driver/handlers/mtu/mtuhandler_test.go +++ b/pkg/routeagent_driver/handlers/mtu/mtuhandler_test.go @@ -18,6 +18,7 @@ limitations under the License. package mtu_test import ( + "context" "strconv" . "github.com/onsi/ginkgo/v2" @@ -140,7 +141,7 @@ func newTestDriver() *testDriver { JustBeforeEach(func() { t.handler = mtu.NewMTUHandler([]string{localCIDR}, t.isGlobalnet, t.tcpMssValue) - Expect(t.handler.Init()).To(Succeed()) + Expect(t.handler.Init(context.TODO())).To(Succeed()) }) return t diff --git a/pkg/routeagent_driver/handlers/ovn/connection.go b/pkg/routeagent_driver/handlers/ovn/connection.go index 06eb510f0..09e65757a 100644 --- a/pkg/routeagent_driver/handlers/ovn/connection.go +++ b/pkg/routeagent_driver/handlers/ovn/connection.go @@ -58,14 +58,14 @@ func NewConnectionHandler(k8sClientset clientset.Interface, dynamicClient dynami } } -func (c *ConnectionHandler) initClients(newOVSDBClient NewOVSDBClientFn) error { +func (c *ConnectionHandler) initClients(ctx context.Context, newOVSDBClient NewOVSDBClientFn) error { // Create nbdb client nbdbModel, err := nbdb.FullDatabaseModel() if err != nil { return errors.Wrap(err, "error getting OVN NBDB database model") } - c.nbdb, err = c.createLibovsdbClient(nbdbModel, newOVSDBClient) + c.nbdb, err = c.createLibovsdbClient(ctx, nbdbModel, newOVSDBClient) if err != nil { return errors.Wrap(err, "error creating NBDB connection") } @@ -96,7 +96,8 @@ func getOVNTLSConfig(pkFile, certFile, caFile string) (*tls.Config, error) { }, nil } -func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, newClient NewOVSDBClientFn) (libovsdbclient.Client, error) { +func (c *ConnectionHandler) createLibovsdbClient(ctx context.Context, dbModel model.ClientDBModel, newClient NewOVSDBClientFn, +) (libovsdbclient.Client, error) { options := []libovsdbclient.Option{ // Reading and parsing the DB after reconnect at scale can (unsurprisingly) // take longer than a normal ovsdb operation. Give it a bit more time so @@ -105,7 +106,7 @@ func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, ne libovsdbclient.WithLogger(&logger.Logger), } - localNode, err := node.GetLocalNode(c.k8sClientset) + localNode, err := node.GetLocalNode(ctx, c.k8sClientset) if err != nil { return nil, errors.Wrap(err, "error getting the node") } @@ -114,7 +115,7 @@ func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, ne // Will use empty zone if not found zoneName := annotations[constants.OvnZoneAnnotation] - dbAddress, err := discoverOvnKubernetesNetwork(context.TODO(), c.k8sClientset, c.dynamicClient, zoneName) + dbAddress, err := discoverOvnKubernetesNetwork(ctx, c.k8sClientset, c.dynamicClient, zoneName) if err != nil { return nil, errors.Wrap(err, "error getting the OVN NBDB Address") } @@ -122,7 +123,7 @@ func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, ne options = append(options, libovsdbclient.WithEndpoint(dbAddress)) if strings.HasPrefix(dbAddress, "ssl:") { - tlsConfig, err := getTLSConfig(c.k8sClientset) + tlsConfig, err := getTLSConfig(ctx, c.k8sClientset) if err != nil { return nil, err } @@ -135,19 +136,19 @@ func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, ne return nil, errors.Wrap(err, "error creating ovsdbClient") } - ctx, cancel := context.WithTimeout(context.Background(), ovsDBTimeout) + clientCtx, cancel := context.WithTimeout(ctx, ovsDBTimeout) defer cancel() - err = client.Connect(ctx) + err = client.Connect(clientCtx) err = errors.Wrap(err, "error connecting to ovsdb") if err == nil { if dbModel.Name() == "OVN_Northbound" { - _, err = client.MonitorAll(ctx) + _, err = client.MonitorAll(clientCtx) err = errors.Wrap(err, "error setting OVN NBDB client to monitor-all") } else { // Only Monitor Required SBDB tables to reduce memory overhead - _, err = client.Monitor(ctx, + _, err = client.Monitor(clientCtx, client.NewMonitor( libovsdbclient.WithTable(&sbdb.Chassis{}), ), @@ -164,23 +165,23 @@ func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, ne return client, nil } -func getFile(k8sClientset clientset.Interface, url string) (string, error) { - file, err := clusterfiles.Get(k8sClientset, url) +func getFile(ctx context.Context, k8sClientset clientset.Interface, url string) (string, error) { + file, err := clusterfiles.Get(ctx, k8sClientset, url) return file, errors.Wrapf(err, "error getting config file for %q", url) } -func getTLSConfig(k8sClientset clientset.Interface) (*tls.Config, error) { - certFile, err := getFile(k8sClientset, getOVNCertPath()) +func getTLSConfig(ctx context.Context, k8sClientset clientset.Interface) (*tls.Config, error) { + certFile, err := getFile(ctx, k8sClientset, getOVNCertPath()) if err != nil { return nil, err } - pkFile, err := getFile(k8sClientset, getOVNPrivKeyPath()) + pkFile, err := getFile(ctx, k8sClientset, getOVNPrivKeyPath()) if err != nil { return nil, err } - caFile, err := getFile(k8sClientset, getOVNCaBundlePath()) + caFile, err := getFile(ctx, k8sClientset, getOVNCaBundlePath()) if err != nil { return nil, err } diff --git a/pkg/routeagent_driver/handlers/ovn/gateway_route_handler.go b/pkg/routeagent_driver/handlers/ovn/gateway_route_handler.go index 34b2fca36..cdec24254 100644 --- a/pkg/routeagent_driver/handlers/ovn/gateway_route_handler.go +++ b/pkg/routeagent_driver/handlers/ovn/gateway_route_handler.go @@ -44,7 +44,7 @@ func NewGatewayRouteHandler(smClientSet submarinerClientset.Interface) *GatewayR } } -func (h *GatewayRouteHandler) Init() error { +func (h *GatewayRouteHandler) Init(_ context.Context) error { logger.Info("Starting GatewayRouteHandler") nextHopIP, err := getNextHopOnK8sMgmtIntf() diff --git a/pkg/routeagent_driver/handlers/ovn/handler.go b/pkg/routeagent_driver/handlers/ovn/handler.go index 2e1de1cad..df4fb8795 100644 --- a/pkg/routeagent_driver/handlers/ovn/handler.go +++ b/pkg/routeagent_driver/handlers/ovn/handler.go @@ -19,6 +19,7 @@ limitations under the License. package ovn import ( + "context" "net" "sync" @@ -97,7 +98,7 @@ func (ovn *Handler) GetNetworkPlugins() []string { return []string{cni.OVNKubernetes} } -func (ovn *Handler) Init() error { +func (ovn *Handler) Init(ctx context.Context) error { ovn.LegacyCleanup() err := ovn.initIPtablesChains() @@ -109,12 +110,12 @@ func (ovn *Handler) Init() error { connectionHandler := NewConnectionHandler(ovn.K8sClient, ovn.DynClient) - err = connectionHandler.initClients(ovn.NewOVSDBClient) + err = connectionHandler.initClients(ctx, ovn.NewOVSDBClient) if err != nil { return errors.Wrapf(err, "error getting connection handler to connect to OvnDB") } - err = ovn.TransitSwitchIP.Init(ovn.K8sClient) + err = ovn.TransitSwitchIP.Init(ctx, ovn.K8sClient) if err != nil { return errors.Wrap(err, "error initializing TransitSwitchIP") } diff --git a/pkg/routeagent_driver/handlers/ovn/non_gateway_route_handler.go b/pkg/routeagent_driver/handlers/ovn/non_gateway_route_handler.go index 884bf95dc..3d2aca8e0 100644 --- a/pkg/routeagent_driver/handlers/ovn/non_gateway_route_handler.go +++ b/pkg/routeagent_driver/handlers/ovn/non_gateway_route_handler.go @@ -48,7 +48,7 @@ func NewNonGatewayRouteHandler(smClient submarinerClientset.Interface, transitSw } } -func (h *NonGatewayRouteHandler) Init() error { +func (h *NonGatewayRouteHandler) Init(_ context.Context) error { logger.Info("Starting NonGatewayRouteHandler") return nil } diff --git a/pkg/routeagent_driver/handlers/ovn/non_gateway_route_handler_test.go b/pkg/routeagent_driver/handlers/ovn/non_gateway_route_handler_test.go index 081359989..9bad78cfa 100644 --- a/pkg/routeagent_driver/handlers/ovn/non_gateway_route_handler_test.go +++ b/pkg/routeagent_driver/handlers/ovn/non_gateway_route_handler_test.go @@ -19,6 +19,7 @@ limitations under the License. package ovn_test import ( + "context" "errors" "os" @@ -40,7 +41,7 @@ var _ = Describe("NonGatewayRouteHandler", func() { JustBeforeEach(func() { tsIP := ovn.NewTransitSwitchIP() t.Start(ovn.NewNonGatewayRouteHandler(t.submClient, tsIP)) - Expect(tsIP.Init(t.k8sClient)).To(Succeed()) + Expect(tsIP.Init(context.TODO(), t.k8sClient)).To(Succeed()) }) awaitNonGatewayRoute := func(ep *submarinerv1.Endpoint) { diff --git a/pkg/routeagent_driver/handlers/ovn/transit_switch_ip.go b/pkg/routeagent_driver/handlers/ovn/transit_switch_ip.go index d171902a9..8ffeed758 100644 --- a/pkg/routeagent_driver/handlers/ovn/transit_switch_ip.go +++ b/pkg/routeagent_driver/handlers/ovn/transit_switch_ip.go @@ -19,6 +19,7 @@ limitations under the License. package ovn import ( + "context" "os" "sync/atomic" @@ -35,7 +36,7 @@ type TransitSwitchIPGetter interface { type TransitSwitchIP interface { TransitSwitchIPGetter - Init(k8sClient kubernetes.Interface) error + Init(ctx context.Context, k8sClient kubernetes.Interface) error UpdateFrom(node *corev1.Node) (bool, error) } @@ -54,8 +55,8 @@ func (t *transitSwitchIPImpl) Get() string { return t.value.Load().(string) } -func (t *transitSwitchIPImpl) Init(k8sClient kubernetes.Interface) error { - node, err := nodeutil.GetLocalNode(k8sClient) +func (t *transitSwitchIPImpl) Init(ctx context.Context, k8sClient kubernetes.Interface) error { + node, err := nodeutil.GetLocalNode(ctx, k8sClient) if err != nil { return errors.Wrap(err, "error getting the local node") } diff --git a/pkg/routeagent_driver/handlers/ovn/transit_switch_ip_test.go b/pkg/routeagent_driver/handlers/ovn/transit_switch_ip_test.go index a94d4fdf7..ffd11eab3 100644 --- a/pkg/routeagent_driver/handlers/ovn/transit_switch_ip_test.go +++ b/pkg/routeagent_driver/handlers/ovn/transit_switch_ip_test.go @@ -57,14 +57,14 @@ var _ = Describe("TransitSwitchIP", func() { }) It("should set the TransitSwitchIP value", func() { - Expect(transitSwitchIP.Init(k8sClient)).To(Succeed()) + Expect(transitSwitchIP.Init(context.TODO(), k8sClient)).To(Succeed()) Expect(transitSwitchIP.Get()).To(Equal(nodeIP)) }) }) When("the node annotation does not exist", func() { It("should succeed and set an empty TransitSwitchIP value", func() { - Expect(transitSwitchIP.Init(k8sClient)).To(Succeed()) + Expect(transitSwitchIP.Init(context.TODO(), k8sClient)).To(Succeed()) Expect(transitSwitchIP.Get()).To(Equal("")) }) }) @@ -75,7 +75,7 @@ var _ = Describe("TransitSwitchIP", func() { }) It("should fail", func() { - Expect(transitSwitchIP.Init(k8sClient)).ToNot(Succeed()) + Expect(transitSwitchIP.Init(context.TODO(), k8sClient)).ToNot(Succeed()) }) }) @@ -87,7 +87,7 @@ var _ = Describe("TransitSwitchIP", func() { }) It("should fail", func() { - Expect(transitSwitchIP.Init(k8sClient)).ToNot(Succeed()) + Expect(transitSwitchIP.Init(context.TODO(), k8sClient)).ToNot(Succeed()) }) }) }) diff --git a/pkg/routeagent_driver/main.go b/pkg/routeagent_driver/main.go index 4af716d55..dda20c7e4 100644 --- a/pkg/routeagent_driver/main.go +++ b/pkg/routeagent_driver/main.go @@ -132,7 +132,7 @@ func main() { config := &watcher.Config{RestConfig: cfg} - localNode, err := node.GetLocalNode(k8sClientSet) + localNode, err := node.GetLocalNode(ctx, k8sClientSet) logger.FatalOnError(err, "Error getting information on the local node") healthcheckerConfig := &healthchecker.Config{ @@ -142,7 +142,7 @@ func main() { RouteAgentUpdateInterval: 60 * time.Second, } - registry, err := event.NewRegistry("routeagent_driver", np, + registry, err := event.NewRegistry(ctx, "routeagent_driver", np, kubeproxy.NewSyncHandler(env.ClusterCidr, env.ServiceCidr), ovn.NewHandler(&ovn.HandlerConfig{ Namespace: env.Namespace, diff --git a/pkg/util/clusterfiles/cluster_files.go b/pkg/util/clusterfiles/cluster_files.go index 9e93e15b7..a76214f54 100644 --- a/pkg/util/clusterfiles/cluster_files.go +++ b/pkg/util/clusterfiles/cluster_files.go @@ -38,7 +38,7 @@ var logger = log.Logger{Logger: logf.Log.WithName("ClusterFiles")} // using an url schema that supports configmap://// // secret://// and file:/// returning // a local path to the file. -func Get(k8sClient kubernetes.Interface, urlAddress string) (string, error) { +func Get(ctx context.Context, k8sClient kubernetes.Interface, urlAddress string) (string, error) { logger.V(log.DEBUG).Infof("Reading cluster_file: %s", urlAddress) parsedURL, err := url.Parse(urlAddress) @@ -61,7 +61,7 @@ func Get(k8sClient kubernetes.Interface, urlAddress string) (string, error) { return parsedURL.Path, nil case "secret": - secret, err := k8sClient.CoreV1().Secrets(namespace).Get(context.TODO(), pathContainerObject, metav1.GetOptions{}) + secret, err := k8sClient.CoreV1().Secrets(namespace).Get(ctx, pathContainerObject, metav1.GetOptions{}) if err != nil { return "", errors.Wrapf(err, "error reading secret %q from namespace %q", pathContainerObject, namespace) } @@ -74,7 +74,7 @@ func Get(k8sClient kubernetes.Interface, urlAddress string) (string, error) { } case "configmap": - configMap, err := k8sClient.CoreV1().ConfigMaps(namespace).Get(context.TODO(), pathContainerObject, metav1.GetOptions{}) + configMap, err := k8sClient.CoreV1().ConfigMaps(namespace).Get(ctx, pathContainerObject, metav1.GetOptions{}) if err != nil { return "", errors.Wrapf(err, "error reading configmap %q from namespace %q", pathContainerObject, namespace) } diff --git a/pkg/util/clusterfiles/cluster_files_test.go b/pkg/util/clusterfiles/cluster_files_test.go index 81c17dc0e..27bb179b1 100644 --- a/pkg/util/clusterfiles/cluster_files_test.go +++ b/pkg/util/clusterfiles/cluster_files_test.go @@ -19,6 +19,7 @@ limitations under the License. package clusterfiles_test import ( + "context" "os" "testing" @@ -45,6 +46,8 @@ var _ = BeforeSuite(func() { }) var _ = Describe("Cluster Files Get", func() { + ctx := context.TODO() + var client kubernetes.Interface BeforeEach(func() { client = fake.NewClientset( @@ -70,39 +73,39 @@ var _ = Describe("Cluster Files Get", func() { When("The scheme is unknown", func() { It("should return an error", func() { - _, err := clusterfiles.Get(client, "randomschema://ns1/my-secret-noo/data1") + _, err := clusterfiles.Get(ctx, client, "randomschema://ns1/my-secret-noo/data1") Expect(err).To(HaveOccurred()) }) }) When("a file source does not exist", func() { It("should return an error", func() { - _, err := clusterfiles.Get(client, "secret://ns1/my-secret-noo/data1") + _, err := clusterfiles.Get(ctx, client, "secret://ns1/my-secret-noo/data1") Expect(err).To(HaveOccurred()) - _, err = clusterfiles.Get(client, "configmap://ns1/my-configmap-noo/data1") + _, err = clusterfiles.Get(ctx, client, "configmap://ns1/my-configmap-noo/data1") Expect(err).To(HaveOccurred()) }) }) When("the content inside the file does not exist", func() { It("should return an error", func() { - _, err := clusterfiles.Get(client, "secret://ns1/my-secret/data1-does-not-exist") + _, err := clusterfiles.Get(ctx, client, "secret://ns1/my-secret/data1-does-not-exist") Expect(err).To(HaveOccurred()) }) }) When("the URL is malformed", func() { It("should return an error", func() { - _, err := clusterfiles.Get(client, "secret://ns1/") + _, err := clusterfiles.Get(ctx, client, "secret://ns1/") Expect(err).To(HaveOccurred()) - _, err = clusterfiles.Get(client, "secret://ns1/secret-with-no-content-detail") + _, err = clusterfiles.Get(ctx, client, "secret://ns1/secret-with-no-content-detail") Expect(err).To(HaveOccurred()) }) }) When("the source secret exist", func() { It("should return the data in a tmp file", func() { - file, err := clusterfiles.Get(client, "secret://ns1/my-secret/data1") + file, err := clusterfiles.Get(ctx, client, "secret://ns1/my-secret/data1") Expect(err).NotTo(HaveOccurred()) fileContent, err := os.ReadFile(file) Expect(err).NotTo(HaveOccurred()) @@ -112,7 +115,7 @@ var _ = Describe("Cluster Files Get", func() { When("the source configmap exist", func() { It("should return the data in a tmp file", func() { - file, err := clusterfiles.Get(client, "configmap://ns1/my-configmap/data1") + file, err := clusterfiles.Get(ctx, client, "configmap://ns1/my-configmap/data1") Expect(err).NotTo(HaveOccurred()) fileContent, err := os.ReadFile(file) Expect(err).NotTo(HaveOccurred()) @@ -122,7 +125,7 @@ var _ = Describe("Cluster Files Get", func() { When("the source configmap exist and has binary data", func() { It("should return the data in a tmp file", func() { - file, err := clusterfiles.Get(client, "configmap://ns1/my-configmap-binary/data1") + file, err := clusterfiles.Get(ctx, client, "configmap://ns1/my-configmap-binary/data1") Expect(err).NotTo(HaveOccurred()) fileContent, err := os.ReadFile(file) Expect(err).NotTo(HaveOccurred()) @@ -132,7 +135,7 @@ var _ = Describe("Cluster Files Get", func() { When("the source is a file", func() { It("should return the original path for the file:/// scheme", func() { - file, err := clusterfiles.Get(nil, "file:///dir/file") + file, err := clusterfiles.Get(ctx, nil, "file:///dir/file") Expect(err).NotTo(HaveOccurred()) Expect(file).To(Equal("/dir/file")) })