Add ctx param to GetLocalNode
Signed-off-by: Tom Pantelis <tompantelis@gmail.com>
tpantelis authored and skitt committed Nov 25, 2024
1 parent b30261c commit b99f193
Showing 15 changed files with 85 additions and 84 deletions.
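In short, the local-node lookup and every call path that reaches it (endpoint.GetLocalSpec, gateway.New, the route agent's OVN connection handler, and the corresponding tests) now take a caller-supplied context.Context. A simplified signature sketch, paraphrased from the diffs below:

// Before: the helper managed its own context internally.
func GetLocalNode(clientset kubernetes.Interface) (*v1.Node, error)

// After: cancellation and deadlines propagate from the caller.
func GetLocalNode(ctx context.Context, clientset kubernetes.Interface) (*v1.Node, error)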
2 changes: 1 addition & 1 deletion main.go
@@ -132,7 +132,7 @@ func main() {
return
}

gw, err := gateway.New(&gateway.Config{
gw, err := gateway.New(ctx, &gateway.Config{
LeaderElectionConfig: gateway.LeaderElectionConfig{
LeaseDuration: time.Duration(gwLeadershipConfig.LeaseDuration) * time.Second,
RenewDeadline: time.Duration(gwLeadershipConfig.RenewDeadline) * time.Second,
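main.go simply forwards the context it already has into gateway.New. A minimal caller sketch, assuming a signal-aware context and the github.com/submariner-io/submariner/pkg/gateway import path; the real main.go may construct its context and config differently:

package main

import (
	"context"
	"os"
	"os/signal"
	"syscall"

	"github.com/submariner-io/submariner/pkg/gateway"
)

func main() {
	// Assumption: a context cancelled on SIGINT/SIGTERM; not taken from the commit.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	gw, err := gateway.New(ctx, &gateway.Config{ /* leader election, spec, clients... */ })
	if err != nil {
		return
	}

	_ = gw // starting the gateway and error reporting omitted
}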
4 changes: 2 additions & 2 deletions pkg/endpoint/local_endpoint.go
@@ -133,13 +133,13 @@ func (l *Local) Create(ctx context.Context) error {
return err
}

func GetLocalSpec(submSpec *types.SubmarinerSpecification, k8sClient kubernetes.Interface,
func GetLocalSpec(ctx context.Context, submSpec *types.SubmarinerSpecification, k8sClient kubernetes.Interface,
airGappedDeployment bool,
) (*submv1.EndpointSpec, error) {
// We'll panic if submSpec is nil, this is intentional
privateIP := GetLocalIP()

gwNode, err := node.GetLocalNode(k8sClient)
gwNode, err := node.GetLocalNode(ctx, k8sClient)
if err != nil {
return nil, errors.Wrap(err, "getting information on the local node")
}
14 changes: 7 additions & 7 deletions pkg/endpoint/local_endpoint_test.go
@@ -96,7 +96,7 @@ var _ = Describe("GetLocalSpec", func() {
})

It("should return a valid EndpointSpec object", func() {
spec, err := endpoint.GetLocalSpec(submSpec, client, false)
spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, false)

Expect(err).ToNot(HaveOccurred())
Expect(spec.ClusterID).To(Equal("east"))
@@ -116,7 +116,7 @@ var _ = Describe("GetLocalSpec", func() {
})

It("should return the udp-port backend config of the cluster", func() {
spec, err := endpoint.GetLocalSpec(submSpec, client, false)
spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, false)
Expect(err).ToNot(HaveOccurred())
Expect(spec.BackendConfig[testUDPPortLabel]).To(Equal(testClusterUDPPort))
})
@@ -128,14 +128,14 @@ var _ = Describe("GetLocalSpec", func() {
})

It("should return a valid EndpointSpec object", func() {
_, err := endpoint.GetLocalSpec(submSpec, client, false)
_, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, false)
Expect(err).ToNot(HaveOccurred())
})
})

When("the gateway node is not annotated with public-ip", func() {
It("should use empty public-ip in the endpoint object for air-gapped deployments", func() {
spec, err := endpoint.GetLocalSpec(submSpec, client, true)
spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, true)

Expect(err).ToNot(HaveOccurred())
Expect(spec.ClusterID).To(Equal("east"))
@@ -150,7 +150,7 @@
})

It("should use the annotated public-ip for air-gapped deployments", func() {
spec, err := endpoint.GetLocalSpec(submSpec, client, true)
spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, true)

Expect(err).ToNot(HaveOccurred())
Expect(spec.PrivateIP).To(Equal(testPrivateIP))
@@ -164,7 +164,7 @@
})

It("should set the HealthCheckIP", func() {
spec, err := endpoint.GetLocalSpec(submSpec, client, true)
spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, true)
Expect(err).ToNot(HaveOccurred())
Expect(spec.HealthCheckIP).To(Equal(cniInterfaceIP))
})
@@ -175,7 +175,7 @@
})

It("should not set the HealthCheckIP", func() {
spec, err := endpoint.GetLocalSpec(submSpec, client, true)
spec, err := endpoint.GetLocalSpec(context.TODO(), submSpec, client, true)
Expect(err).ToNot(HaveOccurred())
Expect(spec.HealthCheckIP).To(BeEmpty())
})
4 changes: 2 additions & 2 deletions pkg/gateway/gateway.go
@@ -104,7 +104,7 @@ type gatewayType struct {

var logger = log.Logger{Logger: logf.Log.WithName("Gateway")}

func New(config *Config) (Interface, error) {
func New(ctx context.Context, config *Config) (Interface, error) {
logger.Info("Initializing the gateway engine")

g := &gatewayType{
@@ -144,7 +144,7 @@ func New(config *Config) (Interface, error) {
g.airGapped = os.Getenv("AIR_GAPPED_DEPLOYMENT") == "true"
logger.Infof("AIR_GAPPED_DEPLOYMENT is set to %t", g.airGapped)

localEndpointSpec, err := endpoint.GetLocalSpec(&g.Spec, g.KubeClient, g.airGapped)
localEndpointSpec, err := endpoint.GetLocalSpec(ctx, &g.Spec, g.KubeClient, g.airGapped)
if err != nil {
return nil, errors.Wrap(err, "error creating local endpoint object")
}
2 changes: 1 addition & 1 deletion pkg/gateway/gateway_test.go
@@ -321,7 +321,7 @@ func newTestDriver() *testDriver {
})

JustBeforeEach(func() {
gw, err := gateway.New(&t.config)
gw, err := gateway.New(context.TODO(), &t.config)
Expect(err).To(Succeed())

ctx, stop := context.WithCancel(context.Background())
40 changes: 19 additions & 21 deletions pkg/node/node.go
@@ -30,49 +30,47 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/util/retry"
nodeutil "k8s.io/component-helpers/node/util"
logf "sigs.k8s.io/controller-runtime/pkg/log"
)

var logger = log.Logger{Logger: logf.Log.WithName("Node")}

var Retry = wait.Backoff{
Steps: 5,
Duration: 5 * time.Second,
Factor: 1.2,
Jitter: 0.1,
}
// These are public to allow unit tests to override.
var (
PollTimeout = time.Second * 30
PollInterval = time.Second
)

func GetLocalNode(clientset kubernetes.Interface) (*v1.Node, error) {
func GetLocalNode(ctx context.Context, clientset kubernetes.Interface) (*v1.Node, error) {
nodeName, ok := os.LookupEnv("NODE_NAME")
if !ok {
return nil, errors.New("error reading the NODE_NAME from the environment")
}

var node *v1.Node

err := retry.OnError(Retry, func(err error) bool {
logger.Warningf("Error reading the local node - retrying: %v", err)
return true
}, func() error {
var err error
err := wait.PollUntilContextTimeout(ctx, PollInterval, PollTimeout, true,
func(ctx context.Context) (bool, error) {
var err error

node, err = clientset.CoreV1().Nodes().Get(context.TODO(), nodeName, metav1.GetOptions{})
if err != nil {
return errors.Wrapf(err, "unable to find local node %q", nodeName)
}
node, err = clientset.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{})
if err == nil {
return true, nil
}

return nil
})
logger.Warningf("Error retrieving the local node %q - retrying: %v", nodeName, err)

return false, nil
})

return node, errors.Wrapf(err, "failed to get local node %q", nodeName)
}

func WaitForLocalNodeReady(ctx context.Context, client kubernetes.Interface) {
// In most cases the node will already be ready; otherwise, wait forever or until the context is cancelled.
err := wait.PollUntilContextCancel(ctx, time.Second, true, func(_ context.Context) (bool, error) {
localNode, err := GetLocalNode(client) //nolint:contextcheck // TODO - should pass the context parameter
err := wait.PollUntilContextCancel(ctx, time.Second, true, func(ctx context.Context) (bool, error) {
localNode, err := GetLocalNode(ctx, client)

if err != nil {
logger.Error(err, "Error retrieving local node")
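The fixed retry.OnError backoff is replaced with wait.PollUntilContextTimeout, so the lookup now polls every PollInterval for at most PollTimeout and also stops early when the caller's context is cancelled. A minimal usage sketch (not from this commit; the package name, client construction, and function name are illustrative):

package nodeexample // illustrative

import (
	"context"
	"fmt"
	"time"

	"github.com/submariner-io/submariner/pkg/node"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/rest"
)

// lookupLocalNode is an illustrative caller; GetLocalNode still reads NODE_NAME
// from the environment to know which Node object to fetch.
func lookupLocalNode() error {
	cfg, err := rest.InClusterConfig()
	if err != nil {
		return err
	}

	client, err := kubernetes.NewForConfig(cfg)
	if err != nil {
		return err
	}

	// The caller's deadline now applies on top of node.PollTimeout.
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	localNode, err := node.GetLocalNode(ctx, client)
	if err != nil {
		return err
	}

	fmt.Println("local node:", localNode.Name)

	return nil
}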
15 changes: 6 additions & 9 deletions pkg/node/node_test.go
@@ -30,7 +30,6 @@ import (
corev1 "k8s.io/api/core/v1"
v1meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/wait"
fakeK8s "k8s.io/client-go/kubernetes/fake"
nodeutil "k8s.io/component-helpers/node/util"
)
@@ -42,7 +41,7 @@ var _ = Describe("GetLocalNode", func() {

When("the local Node resource exists", func() {
It("should return the resource", func() {
Expect(node.GetLocalNode(t.client)).To(Equal(t.node))
Expect(node.GetLocalNode(context.TODO(), t.client)).To(Equal(t.node))
})
})

@@ -52,7 +51,7 @@
})

It("should return an error", func() {
_, err := node.GetLocalNode(t.client)
_, err := node.GetLocalNode(context.TODO(), t.client)
Expect(err).To(HaveOccurred())
})
})
@@ -63,7 +62,7 @@
})

It("should eventually return the resource", func() {
Expect(node.GetLocalNode(t.client)).To(Equal(t.node))
Expect(node.GetLocalNode(context.TODO(), t.client)).To(Equal(t.node))
})
})

@@ -73,7 +72,7 @@
})

It("should return an error", func() {
_, err := node.GetLocalNode(t.client)
_, err := node.GetLocalNode(context.TODO(), t.client)
Expect(err).To(HaveOccurred())
})
})
@@ -133,10 +132,8 @@ func newTestDriver() *testDriver {
t := &testDriver{}

BeforeEach(func() {
node.Retry = wait.Backoff{
Steps: 2,
Duration: 10 * time.Millisecond,
}
node.PollTimeout = 30 * time.Millisecond
node.PollInterval = 10 * time.Millisecond

t.node = &corev1.Node{
ObjectMeta: v1meta.ObjectMeta{
33 changes: 17 additions & 16 deletions pkg/routeagent_driver/handlers/ovn/connection.go
@@ -58,14 +58,14 @@ func NewConnectionHandler(k8sClientset clientset.Interface, dynamicClient dynami
}
}

func (c *ConnectionHandler) initClients(newOVSDBClient NewOVSDBClientFn) error {
func (c *ConnectionHandler) initClients(ctx context.Context, newOVSDBClient NewOVSDBClientFn) error {
// Create nbdb client
nbdbModel, err := nbdb.FullDatabaseModel()
if err != nil {
return errors.Wrap(err, "error getting OVN NBDB database model")
}

c.nbdb, err = c.createLibovsdbClient(nbdbModel, newOVSDBClient)
c.nbdb, err = c.createLibovsdbClient(ctx, nbdbModel, newOVSDBClient)
if err != nil {
return errors.Wrap(err, "error creating NBDB connection")
}
@@ -96,7 +96,7 @@ func getOVNTLSConfig(pkFile, certFile, caFile string) (*tls.Config, error) {
}, nil
}

func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, newClient NewOVSDBClientFn) (libovsdbclient.Client, error) {
func (c *ConnectionHandler) createLibovsdbClient(ctx context.Context, dbModel model.ClientDBModel, newClient NewOVSDBClientFn,
) (libovsdbclient.Client, error) {
options := []libovsdbclient.Option{
// Reading and parsing the DB after reconnect at scale can (unsurprisingly)
// take longer than a normal ovsdb operation. Give it a bit more time so
@@ -105,7 +106,7 @@ func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, ne
libovsdbclient.WithLogger(&logger.Logger),
}

localNode, err := node.GetLocalNode(c.k8sClientset)
localNode, err := node.GetLocalNode(ctx, c.k8sClientset)
if err != nil {
return nil, errors.Wrap(err, "error getting the node")
}
@@ -114,15 +115,15 @@ func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, ne
// Will use empty zone if not found
zoneName := annotations[constants.OvnZoneAnnotation]

dbAddress, err := discoverOvnKubernetesNetwork(context.TODO(), c.k8sClientset, c.dynamicClient, zoneName)
dbAddress, err := discoverOvnKubernetesNetwork(ctx, c.k8sClientset, c.dynamicClient, zoneName)
if err != nil {
return nil, errors.Wrap(err, "error getting the OVN NBDB Address")
}

options = append(options, libovsdbclient.WithEndpoint(dbAddress))

if strings.HasPrefix(dbAddress, "ssl:") {
tlsConfig, err := getTLSConfig(c.k8sClientset)
tlsConfig, err := getTLSConfig(ctx, c.k8sClientset)
if err != nil {
return nil, err
}
@@ -135,19 +136,19 @@ func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, ne
return nil, errors.Wrap(err, "error creating ovsdbClient")
}

ctx, cancel := context.WithTimeout(context.Background(), ovsDBTimeout)
clientCtx, cancel := context.WithTimeout(ctx, ovsDBTimeout)
defer cancel()

err = client.Connect(ctx)
err = client.Connect(clientCtx)

err = errors.Wrap(err, "error connecting to ovsdb")
if err == nil {
if dbModel.Name() == "OVN_Northbound" {
_, err = client.MonitorAll(ctx)
_, err = client.MonitorAll(clientCtx)
err = errors.Wrap(err, "error setting OVN NBDB client to monitor-all")
} else {
// Only Monitor Required SBDB tables to reduce memory overhead
_, err = client.Monitor(ctx,
_, err = client.Monitor(clientCtx,
client.NewMonitor(
libovsdbclient.WithTable(&sbdb.Chassis{}),
),
@@ -164,23 +165,23 @@ func (c *ConnectionHandler) createLibovsdbClient(dbModel model.ClientDBModel, ne
return client, nil
}

func getFile(k8sClientset clientset.Interface, url string) (string, error) {
file, err := clusterfiles.Get(k8sClientset, url)
func getFile(ctx context.Context, k8sClientset clientset.Interface, url string) (string, error) {
file, err := clusterfiles.Get(ctx, k8sClientset, url)
return file, errors.Wrapf(err, "error getting config file for %q", url)
}

func getTLSConfig(k8sClientset clientset.Interface) (*tls.Config, error) {
certFile, err := getFile(k8sClientset, getOVNCertPath())
func getTLSConfig(ctx context.Context, k8sClientset clientset.Interface) (*tls.Config, error) {
certFile, err := getFile(ctx, k8sClientset, getOVNCertPath())
if err != nil {
return nil, err
}

pkFile, err := getFile(k8sClientset, getOVNPrivKeyPath())
pkFile, err := getFile(ctx, k8sClientset, getOVNPrivKeyPath())
if err != nil {
return nil, err
}

caFile, err := getFile(k8sClientset, getOVNCaBundlePath())
caFile, err := getFile(ctx, k8sClientset, getOVNCaBundlePath())
if err != nil {
return nil, err
}
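Besides threading ctx into the node, cluster-files, and OVN-network lookups, the notable change in this file is that the OVSDB connect/monitor context is now derived from the caller's ctx rather than context.Background(), so cancelling the route agent's context aborts an in-flight connect. The same pattern as an illustrative stand-alone helper (names assumed, not part of the commit):

package ovnexample // illustrative

import (
	"context"
	"time"
)

// runWithTimeout bounds op with a timeout derived from the caller's context, so
// either the timeout expiring or the parent context being cancelled aborts op.
func runWithTimeout(ctx context.Context, timeout time.Duration, op func(context.Context) error) error {
	opCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	return op(opCtx)
}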
6 changes: 3 additions & 3 deletions pkg/routeagent_driver/handlers/ovn/handler.go
@@ -98,7 +98,7 @@ func (ovn *Handler) GetNetworkPlugins() []string {
return []string{cni.OVNKubernetes}
}

func (ovn *Handler) Init(_ context.Context) error {
func (ovn *Handler) Init(ctx context.Context) error {
ovn.LegacyCleanup()

err := ovn.initIPtablesChains()
@@ -110,12 +110,12 @@ func (ovn *Handler) Init(_ context.Context) error {

connectionHandler := NewConnectionHandler(ovn.K8sClient, ovn.DynClient)

err = connectionHandler.initClients(ovn.NewOVSDBClient)
err = connectionHandler.initClients(ctx, ovn.NewOVSDBClient)
if err != nil {
return errors.Wrapf(err, "error getting connection handler to connect to OvnDB")
}

err = ovn.TransitSwitchIP.Init(ovn.K8sClient)
err = ovn.TransitSwitchIP.Init(ctx, ovn.K8sClient)
if err != nil {
return errors.Wrap(err, "error initializing TransitSwitchIP")
}
@@ -19,6 +19,7 @@ limitations under the License.
package ovn_test

import (
"context"
"errors"
"os"

@@ -40,7 +41,7 @@ var _ = Describe("NonGatewayRouteHandler", func() {
JustBeforeEach(func() {
tsIP := ovn.NewTransitSwitchIP()
t.Start(ovn.NewNonGatewayRouteHandler(t.submClient, tsIP))
Expect(tsIP.Init(t.k8sClient)).To(Succeed())
Expect(tsIP.Init(context.TODO(), t.k8sClient)).To(Succeed())
})

awaitNonGatewayRoute := func(ep *submarinerv1.Endpoint) {