From f74c64a81473eb1f5714159407ba72d06d2f56d3 Mon Sep 17 00:00:00 2001 From: Harsh-2909 Date: Sun, 21 Jul 2024 13:15:50 +0530 Subject: [PATCH] feat: Adds functions to check for errors and response in header flags and adds tests for that. --- dns/dns_message_test.go | 5 ----- dns/header_flag.go | 18 ++++++++++++++++++ dns/header_flag_test.go | 24 ++++++++++++++++++++++++ network/client.go | 24 ++++++++++++++++++------ network/client_test.go | 2 +- 5 files changed, 61 insertions(+), 12 deletions(-) diff --git a/dns/dns_message_test.go b/dns/dns_message_test.go index e03c5db..d5d58ba 100644 --- a/dns/dns_message_test.go +++ b/dns/dns_message_test.go @@ -37,9 +37,4 @@ func TestDNSMessage(t *testing.T) { } assert.Equal(t, DNSMessage, *DNSMessageFromBytes(DNSMessageBytes)) }) - - // TODO: Write tests and code checks for the following in all the test files: - // - Whether the Header Flag parsed from response has proper flag to denote that it is a response - // - Whether the Header Flag parsed from response has proper flag to denote that it does not have an error - // - Whether the Resource Record checks for the type of record and parses it accordingly } diff --git a/dns/header_flag.go b/dns/header_flag.go index 6a3622b..2a5fca3 100644 --- a/dns/header_flag.go +++ b/dns/header_flag.go @@ -85,3 +85,21 @@ func HeaderFlagFromBytes(b []byte) *HeaderFlag { return HeaderFlagFromUint16(flag) } + +// HasError returns whether the HeaderFlag has an error. +// It checks the value of the RCode field. +func (hf *HeaderFlag) HasError() bool { + return hf.RCode != RCodeNoError +} + +// IsQuery returns whether the HeaderFlag is a query. +// It checks the value of the QR field. +func (hf *HeaderFlag) IsQuery() bool { + return !hf.QR +} + +// IsResponse returns whether the HeaderFlag is a response. +// It checks the value of the QR field. +func (hf *HeaderFlag) IsResponse() bool { + return hf.QR +} diff --git a/dns/header_flag_test.go b/dns/header_flag_test.go index 09f670a..9fae59e 100644 --- a/dns/header_flag_test.go +++ b/dns/header_flag_test.go @@ -30,4 +30,28 @@ func TestHeaderFlag(t *testing.T) { expected := NewHeaderFlag(false, 0, false, false, true, false, 0, 0) assert.Equal(t, expected, HeaderFlagFromBytes(flagBytes)) }) + + t.Run("Should check if the header flag has an error", func(t *testing.T) { + flag := NewHeaderFlag(false, 0, false, false, true, false, 0, 0) + assert.False(t, flag.HasError()) + + flag = NewHeaderFlag(false, 0, false, false, true, false, 0, 1) + assert.True(t, flag.HasError()) + }) + + t.Run("Should check if the header flag is a query", func(t *testing.T) { + flag := NewHeaderFlag(false, 0, false, false, true, false, 0, 0) + assert.True(t, flag.IsQuery()) + + flag = NewHeaderFlag(true, 0, false, false, true, false, 0, 1) + assert.False(t, flag.IsQuery()) + }) + + t.Run("Should check if the header flag is a response", func(t *testing.T) { + flag := NewHeaderFlag(false, 0, false, false, true, false, 0, 0) + assert.False(t, flag.IsResponse()) + + flag = NewHeaderFlag(true, 0, false, false, true, false, 0, 1) + assert.True(t, flag.IsResponse()) + }) } diff --git a/network/client.go b/network/client.go index d819073..d97a831 100644 --- a/network/client.go +++ b/network/client.go @@ -4,6 +4,7 @@ import ( "dns-resolver-go/dns" "fmt" "net" + "os" "time" ) @@ -112,6 +113,18 @@ func Resolve(domain string, questionType uint16) string { return "" } parsedResponse = dns.DNSMessageFromBytes(response) + fmt.Printf("parsedResponse:\n %+v\n\n", parsedResponse) + flags := dns.HeaderFlagFromUint16(parsedResponse.Header.Flags) + + if flags.HasError() { + fmt.Printf("The DNS server returned an error: %s\n", parsedResponse.Answers[0].RDataParsed) + os.Exit(1) + } + + if flags.IsQuery() { + fmt.Printf("The returned DNS message is not a response.\n") + os.Exit(1) + } if parsedResponse.Header.ANCount > 0 { fmt.Printf("\nNon-authoritative answer:\n") @@ -125,19 +138,18 @@ func Resolve(domain string, questionType uint16) string { } } break - } - - if parsedResponse.Header.ARCount > 0 { + } else if parsedResponse.Header.ARCount > 0 { if ip := getRecord(parsedResponse.AdditionalRRs); ip != "" { dnsServerIP = ip } continue - } - - if parsedResponse.Header.NSCount > 0 { + } else if parsedResponse.Header.NSCount > 0 { if nsDomain := getRecord(parsedResponse.AuthorityRRs); nsDomain != "" { dnsServerIP = Resolve(nsDomain, dns.TypeA) } + } else { + fmt.Printf("No answers found for %s\n", domain) + os.Exit(1) } } return parsedResponse.Answers[0].RDataParsed diff --git a/network/client_test.go b/network/client_test.go index 0316115..5254618 100644 --- a/network/client_test.go +++ b/network/client_test.go @@ -8,7 +8,7 @@ import ( ) func TestClient(t *testing.T) { - t.Run("Should create a client", func(t *testing.T) { + t.Run("Should check if the IDs match", func(t *testing.T) { queryMessage, _ := hex.DecodeString("00160100000100000000000003646e7306676f6f676c6503636f6d0000010001") response, _ := hex.DecodeString("00168080000100020000000003646e7306676f6f676c6503636f6d0000010001c00c0001000100000214000408080808c00c0001000100000214000408080404") wrongResponse := []byte{0, 20}