diff --git a/dns_bench_test.go b/dns_bench_test.go index 664d78efb..b0c89d8f2 100644 --- a/dns_bench_test.go +++ b/dns_bench_test.go @@ -99,6 +99,15 @@ func BenchmarkMsgLengthOnlyQuestion(b *testing.B) { } } +func BenchmarkMsgLengthEscapedName(b *testing.B) { + msg := new(Msg) + msg.SetQuestion(`\1\2\3\4\5\6\7\8\9\0\1\2\3\4\5\6\7\8\9\0\1\2\3\4\5\6\7\8\9\0\1\2\3\4\5.\1\2\3\4\5\6\7\8.\1\2\3.`, TypeANY) + b.ResetTimer() + for i := 0; i < b.N; i++ { + msg.Len() + } +} + func BenchmarkPackDomainName(b *testing.B) { name1 := "12345678901234567890123456789012345.12345678.123." buf := make([]byte, len(name1)+1) diff --git a/length_test.go b/length_test.go index ed450a857..5d897e823 100644 --- a/length_test.go +++ b/length_test.go @@ -439,9 +439,38 @@ func TestMsgCompressLengthEscapingMatch(t *testing.T) { if err != nil { t.Error(err) } - // Len doesn't account for escaping when calculating the length *yet* so - // we're off by three here. This will be fixed in a follow up change. - if predicted != len(buf)+3 { + if predicted != len(buf) { + t.Fatalf("predicted compressed length is wrong: predicted %d, actual %d", predicted, len(buf)) + } +} + +func TestMsgLengthEscaped(t *testing.T) { + msg := new(Msg) + msg.SetQuestion(`\000\001\002.\003\004\005\006\007\008\009.\a\b\c.`, TypeA) + + predicted := msg.Len() + buf, err := msg.Pack() + if err != nil { + t.Error(err) + } + if predicted != len(buf) { + t.Fatalf("predicted compressed length is wrong: predicted %d, actual %d", predicted, len(buf)) + } +} + +func TestMsgCompressLengthEscaped(t *testing.T) { + msg := new(Msg) + msg.Compress = true + msg.SetQuestion("www.example.org.", TypeA) + msg.Answer = append(msg.Answer, &NS{Hdr: RR_Header{Name: `\000\001\002.example.org.`, Rrtype: TypeNS, Class: ClassINET}, Ns: `ns.\e\x\a\m\p\l\e.org.`}) + msg.Answer = append(msg.Answer, &NS{Hdr: RR_Header{Name: `www.\e\x\a\m\p\l\e.org.`, Rrtype: TypeNS, Class: ClassINET}, Ns: "ns.example.org."}) + + predicted := msg.Len() + buf, err := msg.Pack() + if err != nil { + t.Error(err) + } + if predicted != len(buf) { t.Fatalf("predicted compressed length is wrong: predicted %d, actual %d", predicted, len(buf)) } } diff --git a/msg.go b/msg.go index c3d465fc7..5d0969812 100644 --- a/msg.go +++ b/msg.go @@ -17,6 +17,7 @@ import ( "math/big" "math/rand" "strconv" + "strings" "sync" ) @@ -965,16 +966,40 @@ func domainNameLen(s string, off int, compression map[string]struct{}, compress return 1 } - nameLen := len(s) + 1 - if compression == nil { - return nameLen - } + escaped := strings.Contains(s, "\\") - if compress || off < maxCompressionOffset { + if compression != nil && (compress || off < maxCompressionOffset) { // compressionLenSearch will insert the entry into the compression // map if it doesn't contain it. if l, ok := compressionLenSearch(compression, s, off); ok && compress { - nameLen = l + 2 + if escaped { + return escapedNameLen(s[:l]) + 2 + } + + return l + 2 + } + } + + if escaped { + return escapedNameLen(s) + 1 + } + + return len(s) + 1 +} + +func escapedNameLen(s string) int { + nameLen := len(s) + for i := 0; i < len(s); i++ { + if s[i] != '\\' { + continue + } + + if i+3 < len(s) && isDigit(s[i+1]) && isDigit(s[i+2]) && isDigit(s[i+3]) { + nameLen -= 3 + i += 3 + } else { + nameLen-- + i++ } }