From 9a5063fe64d3c6c91106bc016a39286ddc30ee9b Mon Sep 17 00:00:00 2001 From: Tom Pantelis Date: Wed, 1 Mar 2023 13:24:15 -0500 Subject: [PATCH] Use resolver module in CoreDNS plugin Related to https://github.com/submariner-io/lighthouse/issues/214 Signed-off-by: Tom Pantelis --- coredns/gateway/controller.go | 39 +- coredns/gateway/controller_test.go | 8 +- coredns/plugin/handler.go | 24 +- coredns/plugin/handler_test.go | 623 +++++++++++------------- coredns/plugin/lighthouse.go | 26 +- coredns/plugin/record.go | 13 +- coredns/plugin/setup.go | 62 ++- coredns/plugin/setup_internal_test.go | 26 +- coredns/resolver/fake/cluster_status.go | 7 + 9 files changed, 363 insertions(+), 465 deletions(-) diff --git a/coredns/gateway/controller.go b/coredns/gateway/controller.go index e2f1aad60..b080e92bf 100644 --- a/coredns/gateway/controller.go +++ b/coredns/gateway/controller.go @@ -36,20 +36,13 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/dynamic" - "k8s.io/client-go/rest" "k8s.io/client-go/tools/cache" logf "sigs.k8s.io/controller-runtime/pkg/log" ) var logger = log.Logger{Logger: logf.Log.WithName("Gateway")} -type NewClientsetFunc func(c *rest.Config) (dynamic.Interface, error) - -// NewClientset is an indirection hook for unit tests to supply fake client sets. -var NewClientset NewClientsetFunc - type Controller struct { - NewClientset NewClientsetFunc informer cache.Controller store cache.Store queue workqueue.Interface @@ -61,7 +54,6 @@ type Controller struct { func NewController() *Controller { controller := &Controller{ - NewClientset: getNewClientsetFunc(), queue: workqueue.New("Gateway Controller"), stopCh: make(chan struct{}), gatewayAvailable: true, @@ -77,18 +69,8 @@ func NewController() *Controller { return controller } -func getNewClientsetFunc() NewClientsetFunc { - if NewClientset != nil { - return NewClientset - } - - return func(c *rest.Config) (dynamic.Interface, error) { - return dynamic.NewForConfig(c) - } -} - -func (c *Controller) Start(kubeConfig *rest.Config) error { - gwClientset, err := c.getCheckedClientset(kubeConfig) +func (c *Controller) Start(client dynamic.Interface) error { + gwClientset, err := c.getCheckedClientset(client) if apierrors.IsNotFound(err) { logger.Infof("Connectivity component is not installed, disabling Gateway status controller") @@ -100,7 +82,7 @@ func (c *Controller) Start(kubeConfig *rest.Config) error { if err != nil { return err } - + logger.Infof("Starting Gateway status Controller") //nolint:wrapcheck // Let the caller wrap these errors. @@ -211,7 +193,7 @@ func (c *Controller) updateClusterStatusMap(connections []interface{}) { } func (c *Controller) updateLocalClusterIDIfNeeded(clusterID string) { - updateNeeded := clusterID != "" && clusterID != c.LocalClusterID() + updateNeeded := clusterID != "" && clusterID != c.GetLocalClusterID() if updateNeeded { logger.Infof("Updating the gateway localClusterID %q ", clusterID) c.localClusterID.Store(clusterID) @@ -264,15 +246,10 @@ func (c *Controller) getClusterStatusMap() map[string]bool { return c.clusterStatusMap.Load().(map[string]bool) } -func (c *Controller) getCheckedClientset(kubeConfig *rest.Config) (dynamic.ResourceInterface, error) { - clientSet, err := c.NewClientset(kubeConfig) - if err != nil { - return nil, errors.Wrap(err, "error creating client set") - } - +func (c *Controller) getCheckedClientset(client dynamic.Interface) (dynamic.ResourceInterface, error) { // First check if the Submariner resource is present. gvr, _ := schema.ParseResourceArg("submariners.v1alpha1.submariner.io") - list, err := clientSet.Resource(*gvr).Namespace(v1.NamespaceAll).List(context.TODO(), metav1.ListOptions{}) + list, err := client.Resource(*gvr).Namespace(v1.NamespaceAll).List(context.TODO(), metav1.ListOptions{}) if apierrors.IsNotFound(err) || meta.IsNoMatchError(err) || (err == nil && len(list.Items) == 0) { return nil, apierrors.NewNotFound(gvr.GroupResource(), "") } @@ -282,7 +259,7 @@ func (c *Controller) getCheckedClientset(kubeConfig *rest.Config) (dynamic.Resou } gvr, _ = schema.ParseResourceArg("gateways.v1.submariner.io") - gwClient := clientSet.Resource(*gvr).Namespace(v1.NamespaceAll) + gwClient := client.Resource(*gvr).Namespace(v1.NamespaceAll) _, err = gwClient.List(context.TODO(), metav1.ListOptions{}) if apierrors.IsNotFound(err) || meta.IsNoMatchError(err) { @@ -310,6 +287,6 @@ func (c *Controller) IsConnected(clusterID string) bool { return !c.gatewayAvailable || c.getClusterStatusMap()[clusterID] } -func (c *Controller) LocalClusterID() string { +func (c *Controller) GetLocalClusterID() string { return c.localClusterID.Load().(string) } diff --git a/coredns/gateway/controller_test.go b/coredns/gateway/controller_test.go index 546819969..9117f6cd7 100644 --- a/coredns/gateway/controller_test.go +++ b/coredns/gateway/controller_test.go @@ -35,7 +35,6 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/dynamic" fakeClient "k8s.io/client-go/dynamic/fake" - "k8s.io/client-go/rest" ) const ( @@ -209,9 +208,6 @@ func newTestDiver() *testDriver { JustBeforeEach(func() { t.controller = gateway.NewController() - t.controller.NewClientset = func(c *rest.Config) (dynamic.Interface, error) { - return t.dynClient, nil - } if t.submarinerObj != nil { _, err := t.dynClient.Resource(submarinersGVR).Namespace("submariner-operator").Create(context.TODO(), t.submarinerObj, @@ -219,7 +215,7 @@ func newTestDiver() *testDriver { Expect(err).To(Succeed()) } - Expect(t.controller.Start(&rest.Config{})).To(Succeed()) + Expect(t.controller.Start(t.dynClient)).To(Succeed()) }) AfterEach(func() { @@ -259,7 +255,7 @@ func (t *testDriver) localClusterIDUpdateValidationTest(originalLocalClusterID, func (t *testDriver) awaitValidLocalClusterID(clusterID string) { Eventually(func() string { - return t.controller.LocalClusterID() + return t.controller.GetLocalClusterID() }, 5).Should(Equal(clusterID)) } diff --git a/coredns/plugin/handler.go b/coredns/plugin/handler.go index db47b3ddb..665dff244 100644 --- a/coredns/plugin/handler.go +++ b/coredns/plugin/handler.go @@ -26,7 +26,6 @@ import ( "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/request" "github.com/miekg/dns" - "github.com/submariner-io/lighthouse/coredns/serviceimport" ) const PluginName = "lighthouse" @@ -70,25 +69,10 @@ func (lh *Lighthouse) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns func (lh *Lighthouse) getDNSRecord(ctx context.Context, zone string, state *request.Request, w dns.ResponseWriter, r *dns.Msg, pReq *recordRequest, ) (int, error) { - var isHeadless bool - var ( - dnsRecords []serviceimport.DNSRecord - found bool - record *serviceimport.DNSRecord - ) - - record, found = lh.getClusterIPForSvc(pReq) + dnsRecords, isHeadless, found := lh.Resolver.GetDNSRecords(pReq.namespace, pReq.service, pReq.cluster, pReq.hostname) if !found { - dnsRecords, found = lh.EndpointSlices.GetDNSRecords(pReq.hostname, pReq.cluster, pReq.namespace, - pReq.service, lh.ClusterStatus.IsConnected) - if !found { - log.Debugf("No record found for %q", state.QName()) - return lh.nextOrFailure(ctx, state, r, dns.RcodeNameError) - } - - isHeadless = true - } else if record != nil && record.IP != "" { - dnsRecords = append(dnsRecords, *record) + log.Debugf("No record found for %q", state.QName()) + return lh.nextOrFailure(ctx, state, r, dns.RcodeNameError) } if len(dnsRecords) == 0 { @@ -102,7 +86,7 @@ func (lh *Lighthouse) getDNSRecord(ctx context.Context, zone string, state *requ } // Count records - localClusterID := lh.ClusterStatus.LocalClusterID() + localClusterID := lh.ClusterStatus.GetLocalClusterID() for _, record := range dnsRecords { incDNSQueryCounter(localClusterID, record.ClusterName, pReq.service, pReq.namespace, record.IP) } diff --git a/coredns/plugin/handler_test.go b/coredns/plugin/handler_test.go index 304b429ba..23432177c 100644 --- a/coredns/plugin/handler_test.go +++ b/coredns/plugin/handler_test.go @@ -31,13 +31,14 @@ import ( . "github.com/onsi/gomega" "github.com/pkg/errors" "github.com/submariner-io/lighthouse/coredns/constants" - "github.com/submariner-io/lighthouse/coredns/endpointslice" lighthouse "github.com/submariner-io/lighthouse/coredns/plugin" - "github.com/submariner-io/lighthouse/coredns/serviceimport" + "github.com/submariner-io/lighthouse/coredns/resolver" + fakecs "github.com/submariner-io/lighthouse/coredns/resolver/fake" v1 "k8s.io/api/core/v1" discovery "k8s.io/api/discovery/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - fakeKubeClient "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/dynamic/fake" + "k8s.io/client-go/kubernetes/scheme" mcsv1a1 "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1" ) @@ -52,23 +53,31 @@ const ( clusterID2 = "cluster2" endpointIP = "100.96.157.101" endpointIP2 = "100.96.157.102" - portName1 = "http" - portName2 = "dns" - protocol1 = v1.ProtocolTCP - portNumber1 = int32(8080) - protocol2 = v1.ProtocolUDP - portNumber2 = int32(53) hostName1 = "hostName1" hostName2 = "hostName2" ) +var ( + port1 = mcsv1a1.ServicePort{ + Name: "http", + Protocol: v1.ProtocolTCP, + Port: 8080, + } + + port2 = mcsv1a1.ServicePort{ + Name: "dns", + Protocol: v1.ProtocolUDP, + Port: 53, + } +) + var _ = Describe("Lighthouse DNS plugin Handler", func() { Context("Fallthrough not configured", testWithoutFallback) Context("Fallthrough configured", testWithFallback) Context("Cluster connectivity status", testClusterStatus) Context("Headless services", testHeadlessService) Context("Local services", testLocalService) - Context("SRV records", testSRVMultiplePorts) + Context("Service with multiple ports", testSRVMultiplePorts) }) type FailingResponseWriter struct { @@ -76,39 +85,6 @@ type FailingResponseWriter struct { errorMsg string } -type MockClusterStatus struct { - clusterStatusMap map[string]bool - localClusterID string -} - -func NewMockClusterStatus() *MockClusterStatus { - return &MockClusterStatus{clusterStatusMap: make(map[string]bool), localClusterID: ""} -} - -func (m *MockClusterStatus) IsConnected(clusterID string) bool { - return m.clusterStatusMap[clusterID] -} - -type MockEndpointStatus struct { - endpointStatusMap map[string]bool -} - -func NewMockEndpointStatus() *MockEndpointStatus { - return &MockEndpointStatus{endpointStatusMap: make(map[string]bool)} -} - -func (m *MockEndpointStatus) IsHealthy(name, namespace, clusterID string) bool { - return m.endpointStatusMap[clusterID] -} - -func (m *MockClusterStatus) LocalClusterID() string { - return m.localClusterID -} - -func getKey(name, namespace string) string { - return namespace + "/" + name -} - func (w *FailingResponseWriter) WriteMsg(m *dns.Msg) error { return errors.New(w.errorMsg) } @@ -121,18 +97,20 @@ func testWithoutFallback() { BeforeEach(func() { t = newHandlerTestDriver() - t.mockCs.clusterStatusMap[clusterID] = true - t.mockEs.endpointStatusMap[clusterID] = true + t.mockCs.ConnectClusterID(clusterID) - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID, serviceIP, portName1, portNumber1, - protocol1, mcsv1a1.ClusterSetIP)) + t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, clusterID, serviceIP, mcsv1a1.ClusterSetIP, port1)) + + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1}, + newEndpoint(serviceIP, "", true))) rec = dnstest.NewRecorder(&test.ResponseWriter{}) }) - When("DNS query for an existing service", func() { + Context("DNS query for an existing service", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) - It("of Type A record should succeed and write an A record response", func() { + + Specify("of Type A record should succeed and write an A record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, @@ -142,21 +120,23 @@ func testWithoutFallback() { }, }) }) - It("of Type SRV should succeed and write an SRV record response", func() { + + Specify("of Type SRV should succeed and write an SRV record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), }, }) }) }) - When("DNS query for an existing service in specific cluster", func() { + Context("DNS query for an existing service in a specific cluster", func() { qname := fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID, service1, namespace1) - It("of Type A record should succeed and write an A record response", func() { + + Specify("of Type A record should succeed and write an A record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Rcode: dns.RcodeSuccess, @@ -167,23 +147,29 @@ func testWithoutFallback() { }) }) - It("of Type SRV should succeed and write an SRV record response", func() { + Specify("of Type SRV should succeed and write an SRV record response", func() { t.executeTestCase(rec, test.Case{ Qtype: dns.TypeSRV, Qname: qname, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), }, }) }) }) - When("DNS query for an existing service with a different namespace", func() { + Context("DNS query for an existing service with a different namespace", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace2) - It("of Type A record should succeed and write an A record response", func() { - t.lh.ServiceImports.Put(newServiceImport(namespace2, service1, clusterID, serviceIP, portName1, - portNumber1, protocol1, mcsv1a1.ClusterSetIP)) + + BeforeEach(func() { + t.lh.Resolver.PutServiceImport(newServiceImport(namespace2, service1, clusterID, serviceIP, mcsv1a1.ClusterSetIP, port1)) + + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace2, service1, clusterID, []mcsv1a1.ServicePort{port1}, + newEndpoint(serviceIP, "", true))) + }) + + Specify("of Type A record should succeed and write an A record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, @@ -193,30 +179,31 @@ func testWithoutFallback() { }, }) }) - It("of Type SRV should succeed and write an SRV record response", func() { - t.lh.ServiceImports.Put(newServiceImport(namespace2, service1, clusterID, serviceIP, portName1, portNumber1, - protocol1, mcsv1a1.ClusterSetIP)) + + Specify("of Type SRV should succeed and write an SRV record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), }, }) }) }) - When("DNS query for a non-existent service", func() { + Context("DNS query for a non-existent service", func() { qname := fmt.Sprintf("unknown.%s.svc.clusterset.local.", namespace1) - It("of Type A record should return RcodeNameError for A record query", func() { + + Specify("of Type A record should return RcodeNameError for A record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, Rcode: dns.RcodeNameError, }) }) - It("of Type SRV should return RcodeNameError for SRV record query", func() { + + Specify("of Type SRV should return RcodeNameError for SRV record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, @@ -225,16 +212,18 @@ func testWithoutFallback() { }) }) - When("DNS query for a non-existent service with a different namespace", func() { + Context("DNS query for a non-existent service with a different namespace", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace2) - It("of Type A record should return RcodeNameError for A record query ", func() { + + Specify("of Type A record should return RcodeNameError for A record query ", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, Rcode: dns.RcodeNameError, }) }) - It("of Type SRV should return RcodeNameError for SRV record query ", func() { + + Specify("of Type SRV should return RcodeNameError for SRV record query ", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, @@ -243,16 +232,18 @@ func testWithoutFallback() { }) }) - When("DNS query for a pod", func() { + Context("DNS query for a pod", func() { qname := fmt.Sprintf("%s.%s.pod.clusterset.local.", service1, namespace1) - It("of Type A record should return RcodeNameError for A record query", func() { + + Specify("of Type A record should return RcodeNameError for A record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, Rcode: dns.RcodeNameError, }) }) - It("of Type SRV should return RcodeNameError for SRV record query", func() { + + Specify("of Type SRV should return RcodeNameError for SRV record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, @@ -261,16 +252,18 @@ func testWithoutFallback() { }) }) - When("DNS query for a non-existent zone", func() { + Context("DNS query for a non-existent zone", func() { qname := fmt.Sprintf("%s.%s.svc.cluster.east.", service1, namespace2) - It("of Type A record should return RcodeNameError for A record query", func() { + + Specify("of Type A record should return RcodeNameError for A record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, Rcode: dns.RcodeNotZone, }) }) - It("of Type SRV should return RcodeNameError for SRV record query", func() { + + Specify("of Type SRV should return RcodeNameError for SRV record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, @@ -279,9 +272,10 @@ func testWithoutFallback() { }) }) - When("type AAAA DNS query", func() { + Context("type AAAA DNS query", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) - It("should return empty record", func() { + + Specify("should return empty record", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeAAAA, @@ -297,6 +291,7 @@ func testWithoutFallback() { }) qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) + It("should return error RcodeServerFailure", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -315,9 +310,9 @@ func testWithFallback() { BeforeEach(func() { t = newHandlerTestDriver() - t.mockCs.clusterStatusMap[clusterID] = true - t.mockCs.localClusterID = clusterID - t.mockEs.endpointStatusMap[clusterID] = true + t.mockCs.ConnectClusterID(clusterID) + t.mockCs.SetLocalClusterID(clusterID) + t.lh.Fall = fall.F{Zones: []string{"clusterset.local."}} t.lh.Next = test.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { m := new(dns.Msg) @@ -326,15 +321,18 @@ func testWithFallback() { return dns.RcodeBadCookie, nil }) - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID, serviceIP, portName1, portNumber1, - protocol1, mcsv1a1.ClusterSetIP)) + t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, clusterID, serviceIP, mcsv1a1.ClusterSetIP, port1)) + + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1}, + newEndpoint(serviceIP, "", true))) rec = dnstest.NewRecorder(&test.ResponseWriter{}) }) - When("type A DNS query for a non-matching lighthouse zone and matching fallthrough zone", func() { + Context("type A DNS query for a non-matching lighthouse zone and matching fallthrough zone", func() { qname := fmt.Sprintf("%s.%s.svc.cluster.east.", service1, namespace1) - It("should invoke the next plugin", func() { + + Specify("should invoke the next plugin", func() { t.lh.Fall = fall.F{Zones: []string{"clusterset.local.", "cluster.east."}} t.executeTestCase(rec, test.Case{ Qname: qname, @@ -344,9 +342,10 @@ func testWithFallback() { }) }) - When("type A DNS query for a non-matching lighthouse zone and non-matching fallthrough zone", func() { + Context("type A DNS query for a non-matching lighthouse zone and non-matching fallthrough zone", func() { qname := fmt.Sprintf("%s.%s.svc.cluster.east.", service1, namespace1) - It("should not invoke the next plugin", func() { + + Specify("should not invoke the next plugin", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, @@ -355,9 +354,10 @@ func testWithFallback() { }) }) - When("type AAAA DNS query", func() { + Context("type AAAA DNS query", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) - It("should return empty record", func() { + + Specify("should return empty record", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeAAAA, @@ -367,9 +367,10 @@ func testWithFallback() { }) }) - When("type A DNS query for a pod", func() { + Context("type A DNS query for a pod", func() { qname := fmt.Sprintf("%s.%s.pod.clusterset.local.", service1, namespace1) - It("should invoke the next plugin", func() { + + Specify("should invoke the next plugin", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, @@ -378,8 +379,8 @@ func testWithFallback() { }) }) - When("type A DNS query for a non-existent service", func() { - It("should invoke the next plugin", func() { + Context("type A DNS query for a non-existent service", func() { + Specify("should invoke the next plugin", func() { t.executeTestCase(rec, test.Case{ Qname: fmt.Sprintf("unknown.%s.svc.clusterset.local.", namespace1), Qtype: dns.TypeA, @@ -388,8 +389,8 @@ func testWithFallback() { }) }) - When("type SRV DNS query for a non-matching lighthouse zone and matching fallthrough zone", func() { - It("should invoke the next plugin", func() { + Context("type SRV DNS query for a non-matching lighthouse zone and matching fallthrough zone", func() { + Specify("should invoke the next plugin", func() { t.lh.Fall = fall.F{Zones: []string{"clusterset.local.", "cluster.east."}} t.executeTestCase(rec, test.Case{ Qname: fmt.Sprintf("%s.%s.svc.cluster.east.", service1, namespace1), @@ -399,8 +400,8 @@ func testWithFallback() { }) }) - When("type SRV DNS query for a non-matching lighthouse zone and non-matching fallthrough zone", func() { - It("should not invoke the next plugin", func() { + Context("type SRV DNS query for a non-matching lighthouse zone and non-matching fallthrough zone", func() { + Specify("should not invoke the next plugin", func() { t.executeTestCase(rec, test.Case{ Qname: fmt.Sprintf("%s.%s.svc.cluster.east.", service1, namespace1), Qtype: dns.TypeSRV, @@ -409,8 +410,8 @@ func testWithFallback() { }) }) - When("type SRV DNS query for a pod", func() { - It("should invoke the next plugin", func() { + Context("type SRV DNS query for a pod", func() { + Specify("should invoke the next plugin", func() { t.executeTestCase(rec, test.Case{ Qname: fmt.Sprintf("%s.%s.pod.clusterset.local.", service1, namespace1), Qtype: dns.TypeSRV, @@ -419,8 +420,8 @@ func testWithFallback() { }) }) - When("type SRV DNS query for a non-existent service", func() { - It("should invoke the next plugin", func() { + Context("type SRV DNS query for a non-existent service", func() { + Specify("should invoke the next plugin", func() { t.executeTestCase(rec, test.Case{ Qname: fmt.Sprintf("unknown.%s.svc.clusterset.local.", namespace1), Qtype: dns.TypeSRV, @@ -438,24 +439,25 @@ func testClusterStatus() { BeforeEach(func() { t = newHandlerTestDriver() - t.mockCs.clusterStatusMap[clusterID] = true - t.mockCs.clusterStatusMap[clusterID2] = true - t.mockEs.endpointStatusMap[clusterID] = true - t.mockEs.endpointStatusMap[clusterID2] = true + t.mockCs.ConnectClusterID(clusterID) + t.mockCs.ConnectClusterID(clusterID2) + + t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, clusterID, serviceIP, mcsv1a1.ClusterSetIP, port1)) + + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1}, + newEndpoint(serviceIP, "", true))) - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID, serviceIP, portName1, portNumber1, - protocol1, mcsv1a1.ClusterSetIP)) + t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, clusterID2, serviceIP2, mcsv1a1.ClusterSetIP, port1, port2)) + + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID2, []mcsv1a1.ServicePort{port2}, + newEndpoint(serviceIP2, "", true))) rec = dnstest.NewRecorder(&test.ResponseWriter{}) }) - When("service is in two clusters and specific cluster is requested", func() { - JustBeforeEach(func() { - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID2, serviceIP2, portName2, - portNumber2, protocol2, mcsv1a1.ClusterSetIP)) - }) - + When("a service is in two clusters and specific cluster is requested", func() { qname := fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID2, service1, namespace1) + It("should succeed and write that cluster's IP as A record response", func() { t.executeTestCase(rec, test.Case{ Qtype: dns.TypeA, @@ -467,46 +469,14 @@ func testClusterStatus() { }) }) - It("should succeed and write that cluster's IP as SRV record response", func() { - t.executeTestCase(rec, test.Case{ - Qname: qname, - Qtype: dns.TypeSRV, - Rcode: dns.RcodeSuccess, - Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber2, qname)), - }, - }) - }) - }) - - When("service is in two connected clusters and one is not of type ClusterSetIP", func() { - JustBeforeEach(func() { - t.initServiceImportMap() - - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID, serviceIP, portName1, portNumber1, - protocol1, mcsv1a1.ClusterSetIP)) - - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID2, serviceIP2, portName2, - portNumber2, protocol2, "")) - }) - qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) - It("should succeed and write an A record response with the available IP", func() { - t.executeTestCase(rec, test.Case{ - Qname: qname, - Qtype: dns.TypeA, - Rcode: dns.RcodeSuccess, - Answer: []dns.RR{ - test.A(fmt.Sprintf("%s 5 IN A %s", qname, serviceIP)), - }, - }) - }) - It("should succeed and write that cluster's IP as SRV record response", func() { + It("should succeed and write that cluster's ports as SRV record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port2.Port, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), }, }) }) @@ -514,12 +484,11 @@ func testClusterStatus() { When("service is in two clusters and only one is connected", func() { JustBeforeEach(func() { - t.mockCs.clusterStatusMap[clusterID] = false - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID2, serviceIP2, portName1, - portNumber1, protocol1, mcsv1a1.ClusterSetIP)) + t.mockCs.DisconnectClusterID(clusterID) }) qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) + It("should succeed and write an A record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -530,13 +499,14 @@ func testClusterStatus() { }, }) }) + It("should succeed and write an SRV record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), }, }) }) @@ -544,13 +514,11 @@ func testClusterStatus() { When("service is present in two clusters and both are disconnected", func() { JustBeforeEach(func() { - t.mockCs.clusterStatusMap[clusterID] = false - t.mockCs.clusterStatusMap[clusterID2] = false - - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID2, serviceIP2, portName2, - portNumber2, protocol2, mcsv1a1.ClusterSetIP)) + t.mockCs.DisconnectAll() }) + qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) + It("should return empty response (NODATA) for A record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -559,6 +527,7 @@ func testClusterStatus() { Answer: []dns.RR{}, }) }) + It("should return empty response (NODATA) for SRV record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -571,11 +540,13 @@ func testClusterStatus() { When("service is present in one cluster and it is disconnected", func() { JustBeforeEach(func() { - t.mockCs.clusterStatusMap[clusterID] = false - delete(t.mockCs.clusterStatusMap, clusterID2) - t.initServiceImportMap() + t.mockCs.DisconnectClusterID(clusterID) + + t.lh.Resolver.RemoveServiceImport(newServiceImport(namespace1, service1, clusterID2, serviceIP2, mcsv1a1.ClusterSetIP)) }) + qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) + It("should return empty response (NODATA) for A record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -584,6 +555,7 @@ func testClusterStatus() { Answer: []dns.RR{}, }) }) + It("should return empty response (NODATA) for SRV record query", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -597,27 +569,30 @@ func testClusterStatus() { func testHeadlessService() { var ( - rec *dnstest.Recorder - t *handlerTestDriver + rec *dnstest.Recorder + t *handlerTestDriver + endpoints []discovery.Endpoint ) BeforeEach(func() { + endpoints = []discovery.Endpoint{} + t = newHandlerTestDriver() - t.mockCs.clusterStatusMap[clusterID] = true - t.mockCs.localClusterID = clusterID - t.mockEs.endpointStatusMap[clusterID] = true - t.mockEs.endpointStatusMap[clusterID2] = true + + t.mockCs.ConnectClusterID(clusterID) + + t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, clusterID, serviceIP, mcsv1a1.Headless, port1)) rec = dnstest.NewRecorder(&test.ResponseWriter{}) }) - When("headless service has no IPs", func() { - JustBeforeEach(func() { - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1, - portNumber1, protocol1, mcsv1a1.Headless)) - t.lh.EndpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{}, []string{}, portNumber1, protocol1)) - }) + JustBeforeEach(func() { + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1}, endpoints...)) + }) + + When("a headless service has no endpoints", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) + It("should succeed and return empty response (NODATA)", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -626,6 +601,7 @@ func testHeadlessService() { Answer: []dns.RR{}, }) }) + It("should succeed and return empty response (NODATA)", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -635,14 +611,14 @@ func testHeadlessService() { }) }) }) - When("headless service has one IP", func() { - JustBeforeEach(func() { - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1, - portNumber1, protocol1, mcsv1a1.Headless)) - t.lh.EndpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1}, []string{endpointIP}, - portNumber1, protocol1)) + + When("a headless service has one endpoint", func() { + BeforeEach(func() { + endpoints = append(endpoints, newEndpoint(endpointIP, hostName1, true)) }) + qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) + It("should succeed and write an A record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -653,38 +629,39 @@ func testHeadlessService() { }, }) }) + It("should succeed and write an SRV record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s", qname, portNumber1, hostName1, clusterID, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s", qname, port1.Port, hostName1, clusterID, qname)), }, }) }) + It("should succeed and write an SRV record response for query with cluster name", func() { qname = fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID, service1, namespace1) + t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s", qname, portNumber1, hostName1, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s", qname, port1.Port, hostName1, qname)), }, }) }) }) - When("headless service has two IPs", func() { - JustBeforeEach(func() { - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1, portNumber1, protocol1, - mcsv1a1.Headless)) - t.lh.EndpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1, hostName2}, - []string{endpointIP, endpointIP2}, - portNumber1, protocol1)) + When("headless service has two endpoints", func() { + BeforeEach(func() { + endpoints = append(endpoints, newEndpoint(endpointIP, hostName1, true), newEndpoint(endpointIP2, hostName2, true)) }) + qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) + It("should succeed and write two A records as response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -696,61 +673,67 @@ func testHeadlessService() { }, }) }) + It("should succeed and write an SRV record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s", qname, portNumber1, hostName1, clusterID, qname)), - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s", qname, portNumber1, hostName2, clusterID, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s", qname, port1.Port, hostName1, clusterID, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s", qname, port1.Port, hostName2, clusterID, qname)), }, }) }) + It("should succeed and write an SRV record response when port and protocol is queried", func() { - qname = fmt.Sprintf("%s.%s.%s.%s.svc.clusterset.local.", portName1, protocol1, service1, namespace1) + qname = fmt.Sprintf("%s.%s.%s.%s.svc.clusterset.local.", port1.Name, port1.Protocol, service1, namespace1) + t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s.%s.svc.clusterset.local.", - qname, portNumber1, hostName1, clusterID, service1, namespace1)), + qname, port1.Port, hostName1, clusterID, service1, namespace1)), test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s.%s.svc.clusterset.local.", - qname, portNumber1, hostName2, clusterID, service1, namespace1)), + qname, port1.Port, hostName2, clusterID, service1, namespace1)), }, }) }) + It("should succeed and write an SRV record response when port and protocol is queried with underscore prefix", func() { - qname = fmt.Sprintf("_%s._%s.%s.%s.svc.clusterset.local.", portName1, protocol1, service1, namespace1) + qname = fmt.Sprintf("_%s._%s.%s.%s.svc.clusterset.local.", port1.Name, port1.Protocol, service1, namespace1) + t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s.%s.svc.clusterset.local.", - qname, portNumber1, hostName1, clusterID, service1, namespace1)), + qname, port1.Port, hostName1, clusterID, service1, namespace1)), test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.%s.%s.svc.clusterset.local.", - qname, portNumber1, hostName2, clusterID, service1, namespace1)), + qname, port1.Port, hostName2, clusterID, service1, namespace1)), }, }) }) }) When("headless service is present in two clusters", func() { - JustBeforeEach(func() { - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1, - portNumber1, protocol1, mcsv1a1.Headless)) - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID2, "", portName1, - portNumber1, protocol1, mcsv1a1.Headless)) - t.lh.EndpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1}, []string{endpointIP}, - portNumber1, protocol1)) - t.lh.EndpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID2, portName1, []string{hostName2}, []string{endpointIP2}, - portNumber1, protocol1)) - t.mockCs.clusterStatusMap[clusterID2] = true + BeforeEach(func() { + t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, clusterID2, serviceIP, mcsv1a1.Headless, port1)) + + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID2, []mcsv1a1.ServicePort{port1}, + newEndpoint(endpointIP2, hostName2, true))) + + endpoints = append(endpoints, newEndpoint(endpointIP, hostName1, true)) + + t.mockCs.ConnectClusterID(clusterID2) }) - When("no cluster is requested", func() { + + Context("and no cluster is requested", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) + It("should succeed and write all IPs as A records in response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -763,8 +746,10 @@ func testHeadlessService() { }) }) }) - When("requested for a specific cluster", func() { + + Context("and a specific clusteris requested", func() { qname := fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID, service1, namespace1) + It("should succeed and write the cluster's IP as A record in response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -787,32 +772,30 @@ func testLocalService() { BeforeEach(func() { t = newHandlerTestDriver() - t.mockCs.clusterStatusMap[clusterID] = true - t.mockCs.clusterStatusMap[clusterID2] = true - t.mockEs.endpointStatusMap[clusterID] = true - t.mockEs.endpointStatusMap[clusterID2] = true - t.mockCs.localClusterID = clusterID - - localSI := newServiceImport(namespace1, service1, clusterID, serviceIP, portName1, portNumber1, protocol1, mcsv1a1.ClusterSetIP) - localSI.Spec.Ports = append(localSI.Spec.Ports, mcsv1a1.ServicePort{ - Name: portName2, - Protocol: protocol2, - Port: portNumber2, - }) + t.mockCs.ConnectClusterID(clusterID) + t.mockCs.ConnectClusterID(clusterID2) + + t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, clusterID, serviceIP, mcsv1a1.ClusterSetIP, port1)) - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID, serviceIP, portName1, portNumber1, - protocol1, mcsv1a1.ClusterSetIP)) + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1}, + newEndpoint(serviceIP, "", true))) - t.lh.ServiceImports.Put(localSI) - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID2, serviceIP2, portName1, portNumber1, - protocol1, mcsv1a1.ClusterSetIP)) + t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, clusterID2, serviceIP2, mcsv1a1.ClusterSetIP, port1, port2)) + + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID2, []mcsv1a1.ServicePort{port1, port2}, + newEndpoint(serviceIP2, "", true))) rec = dnstest.NewRecorder(&test.ResponseWriter{}) }) - When("service is in local and remote clusters", func() { + JustBeforeEach(func() { + t.mockCs.SetLocalClusterID(clusterID) + }) + + When("a service is in local and remote clusters", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) - It("should succeed and write local cluster's IP as A record response", func() { + + It("should succeed and write the local cluster's IP as A record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, @@ -821,7 +804,7 @@ func testLocalService() { test.A(fmt.Sprintf("%s 5 IN A %s", qname, serviceIP)), }, }) - // Execute again to make sure not round robin + // Execute again to make sure no round robin t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, @@ -831,20 +814,22 @@ func testLocalService() { }, }) }) - It("should succeed and write local cluster's IP as SRV record response", func() { + + It("should succeed and write the local cluster's port as SRV record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), }, }) }) }) - When("service is in local and remote clusters, and remote cluster is requested", func() { + When("a service is in local and remote clusters and the remote cluster is requested", func() { qname := fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID2, service1, namespace1) + It("should succeed and write remote cluster's IP as A record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, @@ -856,24 +841,27 @@ func testLocalService() { }) }) - It("should succeed and write remote cluster's IP as SRV record response", func() { + It("should succeed and write the remote cluster's ports as SRV record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port2.Port, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), }, }) }) }) - When("service is in local and remote clusters, and local has no active endpoints", func() { - JustBeforeEach(func() { - t.mockEs.endpointStatusMap[clusterID] = false + When("service is in local and remote clusters and local has no active endpoints", func() { + BeforeEach(func() { + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1})) }) + qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) - It("should succeed and write remote cluster's IP as A record response", func() { + + It("should succeed and write the remote cluster's IP as A record response", func() { t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeA, @@ -883,16 +871,6 @@ func testLocalService() { }, }) }) - It("should succeed and write remote cluster's IP as SRV record response", func() { - t.executeTestCase(rec, test.Case{ - Qname: qname, - Qtype: dns.TypeSRV, - Rcode: dns.RcodeSuccess, - Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), - }, - }) - }) }) } @@ -904,87 +882,78 @@ func testSRVMultiplePorts() { BeforeEach(func() { t = newHandlerTestDriver() - t.mockCs.clusterStatusMap[clusterID] = true - t.mockEs.endpointStatusMap[clusterID] = true - t.mockCs.localClusterID = clusterID - - localSI := newServiceImport(namespace1, service1, clusterID, serviceIP, portName1, portNumber1, protocol1, mcsv1a1.ClusterSetIP) - localSI.Spec.Ports = append(localSI.Spec.Ports, mcsv1a1.ServicePort{ - Name: portName2, - Protocol: protocol2, - Port: portNumber2, - }) + t.mockCs.ConnectClusterID(clusterID) - t.lh.ServiceImports.Put(localSI) + t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, clusterID, serviceIP, mcsv1a1.ClusterSetIP, port1, port2)) + + t.lh.Resolver.PutEndpointSlice(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1, port2}, + newEndpoint(endpointIP, "", true))) rec = dnstest.NewRecorder(&test.ResponseWriter{}) }) - When("DNS query of type SRV", func() { - It("without portName should return all the ports", func() { + Context("a DNS query of type SRV", func() { + Specify("without a port name should return all the ports", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) + t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber2, qname)), - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port2.Port, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), }, }) }) - It("with HTTP portname should return TCP port", func() { - qname := fmt.Sprintf("%s.%s.%s.%s.svc.clusterset.local.", portName1, protocol1, service1, namespace1) + + Specify("with a port name requested should return only that port", func() { + qname := fmt.Sprintf("%s.%s.%s.%s.svc.clusterset.local.", port1.Name, port1.Protocol, service1, namespace1) + t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname, portNumber1, service1, namespace1)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname, port1.Port, service1, namespace1)), }, }) - }) - It("with DNS portname should return UDP port", func() { - qname := fmt.Sprintf("%s.%s.%s.%s.svc.clusterset.local.", portName2, protocol2, service1, namespace1) + qname = fmt.Sprintf("%s.%s.%s.%s.svc.clusterset.local.", port2.Name, port2.Protocol, service1, namespace1) + t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname, portNumber2, service1, namespace1)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname, port2.Port, service1, namespace1)), }, }) }) - Context("with DNS cluster name", func() { - JustBeforeEach(func() { - t.lh.ServiceImports.Put(newServiceImport(namespace1, service1, clusterID2, serviceIP2, portName1, portNumber1, - protocol1, mcsv1a1.ClusterSetIP)) - }) + Specify("with a DNS cluster name requested should return all the ports from the cluster", func() { + qname := fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID, service1, namespace1) - It("should return all the ports from the cluster", func() { - qname := fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID, service1, namespace1) - t.executeTestCase(rec, test.Case{ - Qname: qname, - Qtype: dns.TypeSRV, - Rcode: dns.RcodeSuccess, - Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber2, qname)), - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, portNumber1, qname)), - }, - }) + t.executeTestCase(rec, test.Case{ + Qname: qname, + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port2.Port, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), + }, }) }) - It("with HTTP portname should return TCP port with underscore prefix", func() { - qname := fmt.Sprintf("_%s._%s.%s.%s.svc.clusterset.local.", portName1, protocol1, service1, namespace1) + Specify("with a port name requested with underscore prefix should return the port", func() { + qname := fmt.Sprintf("_%s._%s.%s.%s.svc.clusterset.local.", port1.Name, port1.Protocol, service1, namespace1) + t.executeTestCase(rec, test.Case{ Qname: qname, Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, Answer: []dns.RR{ - test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname, portNumber1, service1, namespace1)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname, port1.Port, service1, namespace1)), }, }) }) @@ -992,27 +961,22 @@ func testSRVMultiplePorts() { } type handlerTestDriver struct { - mockCs *MockClusterStatus - mockEs *MockEndpointStatus + mockCs *fakecs.ClusterStatus lh *lighthouse.Lighthouse } func newHandlerTestDriver() *handlerTestDriver { t := &handlerTestDriver{ - mockCs: NewMockClusterStatus(), - mockEs: NewMockEndpointStatus(), + mockCs: fakecs.NewClusterStatus(""), } t.lh = &lighthouse.Lighthouse{ - Zones: []string{"clusterset.local."}, - EndpointSlices: setupEndpointSliceMap(), - ClusterStatus: t.mockCs, - EndpointsStatus: t.mockEs, - TTL: uint32(5), + Zones: []string{"clusterset.local."}, + ClusterStatus: t.mockCs, + Resolver: resolver.New(t.mockCs, fake.NewSimpleDynamicClient(scheme.Scheme)), + TTL: uint32(5), } - t.initServiceImportMap() - return t } @@ -1029,21 +993,9 @@ func (t *handlerTestDriver) executeTestCase(rec *dnstest.Recorder, tc test.Case) } } -func (t *handlerTestDriver) initServiceImportMap() { - t.lh.ServiceImports = serviceimport.NewMap(localClusterID) -} - -func setupEndpointSliceMap() *endpointslice.Map { - esMap := endpointslice.NewMap(localClusterID, fakeKubeClient.NewSimpleClientset()) - esMap.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1}, []string{endpointIP}, portNumber1, protocol1)) - - return esMap -} - //nolint:unparam // `name` always receives `service1'. -func newServiceImport(namespace, name, clusterID, serviceIP, portName string, - portNumber int32, protocol v1.Protocol, siType mcsv1a1.ServiceImportType, -) *mcsv1a1.ServiceImport { +func newServiceImport(namespace, name, clusterID, serviceIP string, siType mcsv1a1.ServiceImportType, + ports ...mcsv1a1.ServicePort) *mcsv1a1.ServiceImport { return &mcsv1a1.ServiceImport{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -1055,15 +1007,9 @@ func newServiceImport(namespace, name, clusterID, serviceIP, portName string, }, }, Spec: mcsv1a1.ServiceImportSpec{ - Type: siType, - IPs: []string{serviceIP}, - Ports: []mcsv1a1.ServicePort{ - { - Name: portName, - Protocol: protocol, - Port: portNumber, - }, - }, + Type: siType, + IPs: []string{serviceIP}, + Ports: ports, }, Status: mcsv1a1.ServiceImportStatus{ Clusters: []mcsv1a1.ClusterStatus{ @@ -1076,22 +1022,21 @@ func newServiceImport(namespace, name, clusterID, serviceIP, portName string, } //nolint:unparam // `namespace` always receives `namespace1`. -func newEndpointSlice(namespace, name, clusterID, portName string, hostName, endpointIPs []string, portNumber int32, - protocol v1.Protocol, -) *discovery.EndpointSlice { - endpoints := make([]discovery.Endpoint, len(endpointIPs)) - - for i := range endpointIPs { - endpoint := discovery.Endpoint{ - Addresses: []string{endpointIPs[i]}, - Hostname: &hostName[i], +func newEndpointSlice(namespace, name, clusterID string, ports []mcsv1a1.ServicePort, + endpoints ...discovery.Endpoint) *discovery.EndpointSlice { + epPorts := make([]discovery.EndpointPort, len(ports)) + for i := range ports { + epPorts[i] = discovery.EndpointPort{ + Name: &ports[i].Name, + Protocol: &ports[i].Protocol, + Port: &ports[i].Port, + AppProtocol: ports[i].AppProtocol, } - endpoints[i] = endpoint } return &discovery.EndpointSlice{ ObjectMeta: metav1.ObjectMeta{ - Name: name, + Name: name + "-" + namespace + "-" + clusterID, Namespace: namespace, Labels: map[string]string{ discovery.LabelManagedBy: constants.LabelValueManagedBy, @@ -1101,13 +1046,15 @@ func newEndpointSlice(namespace, name, clusterID, portName string, hostName, end }, }, AddressType: discovery.AddressTypeIPv4, + Ports: epPorts, Endpoints: endpoints, - Ports: []discovery.EndpointPort{ - { - Name: &portName, - Protocol: &protocol, - Port: &portNumber, - }, - }, + } +} + +func newEndpoint(address, hostname string, ready bool) discovery.Endpoint { + return discovery.Endpoint{ + Addresses: []string{address}, + Hostname: &hostname, + Conditions: discovery.EndpointConditions{Ready: &ready}, } } diff --git a/coredns/plugin/lighthouse.go b/coredns/plugin/lighthouse.go index ecb180bc1..587278171 100644 --- a/coredns/plugin/lighthouse.go +++ b/coredns/plugin/lighthouse.go @@ -25,8 +25,7 @@ import ( "github.com/coredns/coredns/plugin/pkg/fall" clog "github.com/coredns/coredns/plugin/pkg/log" "github.com/go-logr/logr" - "github.com/submariner-io/lighthouse/coredns/endpointslice" - "github.com/submariner-io/lighthouse/coredns/serviceimport" + "github.com/submariner-io/lighthouse/coredns/resolver" logf "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -43,23 +42,12 @@ var errInvalidRequest = errors.New("invalid query name") var log = clog.NewWithPlugin(PluginName) type Lighthouse struct { - Next plugin.Handler - Fall fall.F - Zones []string - TTL uint32 - ServiceImports *serviceimport.Map - EndpointSlices *endpointslice.Map - ClusterStatus ClusterStatus - EndpointsStatus EndpointsStatus -} - -type ClusterStatus interface { - IsConnected(clusterID string) bool - LocalClusterID() string -} - -type EndpointsStatus interface { - IsHealthy(name, namespace, clusterID string) bool + Next plugin.Handler + Fall fall.F + Zones []string + TTL uint32 + ClusterStatus resolver.ClusterStatus + Resolver *resolver.Interface } var _ plugin.Handler = &Lighthouse{} diff --git a/coredns/plugin/record.go b/coredns/plugin/record.go index 55f63cc80..bbdbe9052 100644 --- a/coredns/plugin/record.go +++ b/coredns/plugin/record.go @@ -24,11 +24,11 @@ import ( "github.com/coredns/coredns/request" "github.com/miekg/dns" - "github.com/submariner-io/lighthouse/coredns/serviceimport" + "github.com/submariner-io/lighthouse/coredns/resolver" "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1" ) -func (lh *Lighthouse) createARecords(dnsrecords []serviceimport.DNSRecord, state *request.Request) []dns.RR { +func (lh *Lighthouse) createARecords(dnsrecords []resolver.DNSRecord, state *request.Request) []dns.RR { records := make([]dns.RR, 0) for _, record := range dnsrecords { @@ -42,7 +42,7 @@ func (lh *Lighthouse) createARecords(dnsrecords []serviceimport.DNSRecord, state return records } -func (lh *Lighthouse) createSRVRecords(dnsrecords []serviceimport.DNSRecord, state *request.Request, pReq *recordRequest, zone string, +func (lh *Lighthouse) createSRVRecords(dnsrecords []resolver.DNSRecord, state *request.Request, pReq *recordRequest, zone string, isHeadless bool, ) []dns.RR { var records []dns.RR @@ -95,10 +95,3 @@ func (lh *Lighthouse) createSRVRecords(dnsrecords []serviceimport.DNSRecord, sta return records } - -func (lh *Lighthouse) getClusterIPForSvc(pReq *recordRequest) (*serviceimport.DNSRecord, bool) { - localClusterID := lh.ClusterStatus.LocalClusterID() - - return lh.ServiceImports.GetIP(pReq.namespace, pReq.service, pReq.cluster, localClusterID, lh.ClusterStatus.IsConnected, - lh.EndpointsStatus.IsHealthy) -} diff --git a/coredns/plugin/setup.go b/coredns/plugin/setup.go index 8e12f0dfe..21658546c 100644 --- a/coredns/plugin/setup.go +++ b/coredns/plugin/setup.go @@ -26,11 +26,17 @@ import ( "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" "github.com/pkg/errors" - "github.com/submariner-io/lighthouse/coredns/endpointslice" + "github.com/submariner-io/admiral/pkg/watcher" "github.com/submariner-io/lighthouse/coredns/gateway" - "github.com/submariner-io/lighthouse/coredns/serviceimport" - "k8s.io/client-go/kubernetes" + "github.com/submariner-io/lighthouse/coredns/resolver" + discovery "k8s.io/api/discovery/v1" + "k8s.io/apimachinery/pkg/api/meta" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" + mcsv1a1 "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1" ) var ( @@ -38,12 +44,23 @@ var ( kubeconfig string ) -// Hook for unit tests. -var buildKubeConfigFunc = clientcmd.BuildConfigFromFlags +// Hooks for unit tests. +var ( + buildKubeConfigFunc = clientcmd.BuildConfigFromFlags + + newDynamicClient = func(c *rest.Config) (dynamic.Interface, error) { + return dynamic.NewForConfig(c) + } + + restMapper meta.RESTMapper +) // init registers this plugin within the Caddy plugin framework. It uses "example" as the // name, and couples it to the Action "setup". func init() { + utilruntime.Must(mcsv1a1.AddToScheme(scheme.Scheme)) + utilruntime.Must(discovery.AddToScheme(scheme.Scheme)) + caddy.RegisterPlugin(PluginName, caddy.Plugin{ ServerType: "dns", Action: setupLighthouse, @@ -76,40 +93,39 @@ func lighthouseParse(c *caddy.Controller) (*Lighthouse, error) { gwController := gateway.NewController() - err = gwController.Start(cfg) + localClient, err := newDynamicClient(cfg) if err != nil { - return nil, errors.Wrap(err, "error starting the Gateway controller") + return nil, errors.Wrap(err, "error creating local client") } - siMap := serviceimport.NewMap(gwController.LocalClusterID()) - siController := serviceimport.NewController(siMap) + lh := &Lighthouse{ + TTL: defaultTTL, + ClusterStatus: gwController, + Resolver: resolver.New(gwController, localClient), + } - err = siController.Start(cfg) + err = gwController.Start(localClient) if err != nil { - return nil, errors.Wrap(err, "error starting the ServiceImport controller") + return nil, errors.Wrap(err, "error starting the Gateway controller") } - kubeClient := kubernetes.NewForConfigOrDie(cfg) - epMap := endpointslice.NewMap(gwController.LocalClusterID(), kubeClient) - epController := endpointslice.NewController(epMap) + resolverController := resolver.NewController(lh.Resolver) - err = epController.Start(cfg) + err = resolverController.Start(watcher.Config{ + RestConfig: cfg, + Client: localClient, + RestMapper: restMapper, + }) if err != nil { - return nil, errors.Wrap(err, "error starting the EndpointSlice controller") + return nil, errors.Wrap(err, "error starting the resolver controller") } c.OnShutdown(func() error { - siController.Stop() - epController.Stop() gwController.Stop() + resolverController.Stop() return nil }) - lh := &Lighthouse{ - TTL: defaultTTL, ServiceImports: siMap, ClusterStatus: gwController, EndpointSlices: epMap, - EndpointsStatus: epController, - } - // Changed `for` to `if` to satisfy golint: // SA4004: the surrounding loop is unconditionally terminated (staticcheck) if c.Next() { diff --git a/coredns/plugin/setup_internal_test.go b/coredns/plugin/setup_internal_test.go index 9ac6fc297..66f8c5775 100644 --- a/coredns/plugin/setup_internal_test.go +++ b/coredns/plugin/setup_internal_test.go @@ -28,18 +28,14 @@ import ( "github.com/miekg/dns" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/submariner-io/lighthouse/coredns/endpointslice" - "github.com/submariner-io/lighthouse/coredns/gateway" - "github.com/submariner-io/lighthouse/coredns/serviceimport" - "k8s.io/apimachinery/pkg/runtime" + "github.com/submariner-io/admiral/pkg/syncer/test" + discovery "k8s.io/api/discovery/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/dynamic" fakeClient "k8s.io/client-go/dynamic/fake" - "k8s.io/client-go/kubernetes" - fakeKubeClient "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" - mcsClientset "sigs.k8s.io/mcs-api/pkg/client/clientset/versioned" - fakeMCSClientset "sigs.k8s.io/mcs-api/pkg/client/clientset/versioned/fake" + mcsv1a1 "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1" ) type fakeHandler struct{} @@ -66,24 +62,18 @@ var _ = Describe("Plugin setup", func() { Resource: "submariners", } - gateway.NewClientset = func(c *rest.Config) (dynamic.Interface, error) { - return fakeClient.NewSimpleDynamicClientWithCustomListKinds(runtime.NewScheme(), map[schema.GroupVersionResource]string{ + newDynamicClient = func(c *rest.Config) (dynamic.Interface, error) { + return fakeClient.NewSimpleDynamicClientWithCustomListKinds(scheme.Scheme, map[schema.GroupVersionResource]string{ gatewaysGVR: "GatewayList", submarinersGVR: "SubmarinersList", }), nil } - serviceimport.NewClientset = func(kubeConfig *rest.Config) (mcsClientset.Interface, error) { - return fakeMCSClientset.NewSimpleClientset(), nil - } - - endpointslice.NewClientset = func(kubeConfig *rest.Config) (kubernetes.Interface, error) { - return fakeKubeClient.NewSimpleClientset(), nil - } + restMapper = test.GetRESTMapperFor(&discovery.EndpointSlice{}, &mcsv1a1.ServiceImport{}) }) AfterEach(func() { - gateway.NewClientset = nil + newDynamicClient = nil }) Context("Parsing correct configurations", testCorrectConfig) diff --git a/coredns/resolver/fake/cluster_status.go b/coredns/resolver/fake/cluster_status.go index 5931fc57c..e216b1853 100644 --- a/coredns/resolver/fake/cluster_status.go +++ b/coredns/resolver/fake/cluster_status.go @@ -69,3 +69,10 @@ func (c *ClusterStatus) DisconnectClusterID(clusterID string) { c.connectedClusterIDs.Delete(clusterID) } + +func (c *ClusterStatus) ConnectClusterID(clusterID string) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.connectedClusterIDs.Insert(clusterID) +}