Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Support dns upstream failover for nameserver groups with same match domain #3178

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions client/internal/dns/handler_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
const (
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 0
PriorityDefault = 1
)

type SubdomainMatcher interface {
Expand All @@ -26,7 +26,6 @@ type HandlerEntry struct {
Pattern string
OrigPattern string
IsWildcard bool
StopHandler handlerWithStop
MatchSubdomains bool
}

Expand Down Expand Up @@ -64,7 +63,7 @@ func (w *ResponseWriterChain) GetOrigPattern() string {
}

// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
c.mu.Lock()
defer c.mu.Unlock()

Expand All @@ -78,9 +77,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
// First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
Expand All @@ -101,7 +97,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
Pattern: pattern,
OrigPattern: origPattern,
IsWildcard: isWildcard,
StopHandler: stopHandler,
MatchSubdomains: matchSubdomains,
}

Expand Down Expand Up @@ -129,9 +124,6 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return
}
Expand Down Expand Up @@ -167,8 +159,8 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers))
for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
}
}

Expand All @@ -193,13 +185,13 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}

if !matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
continue
}

log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)

chainWriter := &ResponseWriterChain{
ResponseWriter: w,
Expand Down
26 changes: 13 additions & 13 deletions client/internal/dns/handler_chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
dnsRouteHandler := &nbdns.MockHandler{}

// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)

// Create test request
r := new(dns.Msg)
Expand Down Expand Up @@ -138,7 +138,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
pattern = "*." + tt.handlerDomain[2:]
}

chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)

r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
Expand Down Expand Up @@ -253,7 +253,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
}

chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority)
}

// Create and execute request
Expand All @@ -280,9 +280,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3 := &nbdns.MockHandler{}

// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)

// Create test request
r := new(dns.Msg)
Expand Down Expand Up @@ -416,7 +416,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
if op.action == "add" {
handler := &nbdns.MockHandler{}
handlers[op.priority] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil)
chain.AddHandler(op.pattern, handler, op.priority)
} else {
chain.RemoveHandler(op.pattern, op.priority)
}
Expand Down Expand Up @@ -471,9 +471,9 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
r.SetQuestion(testQuery, dns.TypeA)

// Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)

// Test 1: Initial state with all three handlers
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
Expand Down Expand Up @@ -653,7 +653,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
handler = mockHandler
}

chain.AddHandler(pattern, handler, h.priority, nil)
chain.AddHandler(pattern, handler, h.priority)
}

// Execute request
Expand Down
7 changes: 6 additions & 1 deletion client/internal/dns/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@ func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
}

// ID returns the unique handler ID
func (d *localResolver) id() handlerID {
return "local-resolver"
}

// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) > 0 {
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
}

replyMessage := &dns.Msg{}
Expand Down
Loading
Loading