Skip to content

Commit

Permalink
Support dns upstream failover for nameserver groups with same match d…
Browse files Browse the repository at this point in the history
…omain
  • Loading branch information
lixmal committed Jan 13, 2025
1 parent 3fce848 commit c728b16
Show file tree
Hide file tree
Showing 11 changed files with 322 additions and 174 deletions.
20 changes: 6 additions & 14 deletions client/internal/dns/handler_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
const (
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 0
PriorityDefault = 1
)

type SubdomainMatcher interface {
Expand All @@ -26,7 +26,6 @@ type HandlerEntry struct {
Pattern string
OrigPattern string
IsWildcard bool
StopHandler handlerWithStop
MatchSubdomains bool
}

Expand Down Expand Up @@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
}

// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
c.mu.Lock()
defer c.mu.Unlock()

Expand All @@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
// First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
Expand All @@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
Pattern: pattern,
OrigPattern: origPattern,
IsWildcard: isWildcard,
StopHandler: stopHandler,
MatchSubdomains: matchSubdomains,
}

Expand Down Expand Up @@ -129,9 +124,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return
}
Expand Down Expand Up @@ -193,13 +185,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}

if !matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
continue
}

log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)

chainWriter := &ResponseWriterChain{
ResponseWriter: w,
Expand Down
26 changes: 13 additions & 13 deletions client/internal/dns/handler_chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
dnsRouteHandler := &nbdns.MockHandler{}

// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)

// Create test request
r := new(dns.Msg)
Expand Down Expand Up @@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
pattern = "*." + tt.handlerDomain[2:]
}

chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)

r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
Expand Down Expand Up @@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
}

chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
}

// Create and execute request
Expand All @@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3 := &nbdns.MockHandler{}

// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)

// Create test request
r := new(dns.Msg)
Expand Down Expand Up @@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
if op.action == "add" {
handler := &nbdns.MockHandler{}
handlers[op.priority] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil)
chain.AddHandler(op.pattern, handler, op.priority)
} else {
chain.RemoveHandler(op.pattern, op.priority)
}
Expand Down Expand Up @@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
r.SetQuestion(testQuery, dns.TypeA)

// Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)

// Test 1: Initial state with all three handlers
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
Expand Down Expand Up @@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
handler = mockHandler
}

chain.AddHandler(pattern, handler, h.priority, nil)
chain.AddHandler(pattern, handler, h.priority)
}

// Execute request
Expand Down
7 changes: 6 additions & 1 deletion client/internal/dns/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@ func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
}

// ID returns the unique handler ID
func (d *localResolver) id() handlerID {
return "local-resolver"
}

// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) > 0 {
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
}

replyMessage := &dns.Msg{}
Expand Down
Loading

0 comments on commit c728b16

Please sign in to comment.