From f49cd000169649d1eda2097f1fbac876c396d5d7 Mon Sep 17 00:00:00 2001 From: "Dr. Adedayo Adetoye" Date: Sun, 19 Apr 2020 13:46:23 +0100 Subject: [PATCH] Fix: refactoring and tests --- pkg/model/models.go | 5 ++-- pkg/tlsscan.go | 8 +++--- pkg/tlsscan_test.go | 63 +++++++++++++++++++-------------------------- 3 files changed, 33 insertions(+), 43 deletions(-) diff --git a/pkg/model/models.go b/pkg/model/models.go index 63df86a..1c61223 100644 --- a/pkg/model/models.go +++ b/pkg/model/models.go @@ -1004,7 +1004,8 @@ type HumanScanResult struct { SecureRenegotiationSupportedByProtocol map[string]bool CipherSuiteByProtocol map[string][]string // ServerHelloMessageByProtocolByCipher map[string]map[string]ServerHelloMessage - CertificatesPerProtocol map[string][]HumanCertificate + CertificatesPerProtocol map[string][]HumanCertificate + CertificatesWithChainIssue map[string]bool // KeyExchangeByProtocolByCipher map[string]map[string]ServerKeyExchangeMsg IsSTARTLS bool IsSSH bool @@ -1105,7 +1106,6 @@ func (s ScanResult) ToHumanScanResult() (out HumanScanResult) { } out.CipherSuiteByProtocol[tlsdefs.TLSVersionMap[k]] = ciphers } - out.CertificatesPerProtocol = make(map[string][]HumanCertificate) for p, c := range s.CertificatesPerProtocol { certs, err := c.GetCertificates() @@ -1157,6 +1157,7 @@ func (s ScanResult) ToHumanScanResult() (out HumanScanResult) { } } + out.CertificatesWithChainIssue = s.CertificatesWithChainIssue out.IsSTARTLS = s.IsSTARTLS out.IsSSH = s.IsSSH out.SupportsTLSFallbackSCSV = s.SupportsTLSFallbackSCSV diff --git a/pkg/tlsscan.go b/pkg/tlsscan.go index 4d00f5f..44d7a5f 100644 --- a/pkg/tlsscan.go +++ b/pkg/tlsscan.go @@ -38,7 +38,7 @@ func ScanCIDRTLS(cidr string, config tlsmodel.ScanConfig) []tlsmodel.ScanResult scan := make(map[string]tlsmodel.ScanResult) results := []<-chan tlsmodel.ScanResult{} results = append(results, scanCIDRTLS(cidr, config)) - for result := range MergeResultChannels(results...) { + for result := range mergeResultChannels(results...) { key := result.Server + result.Port if _, present := scan[key]; !present { scan[key] = result @@ -125,7 +125,7 @@ func scanCIDRTLS(cidr string, config tlsmodel.ScanConfig) <-chan tlsmodel.ScanRe } } } - for res := range MergeResultChannels(resultChannels...) { + for res := range mergeResultChannels(resultChannels...) { res.HostName = originalDomain scanResults <- res } @@ -180,8 +180,8 @@ func mergeACKChannels(ackChannels ...<-chan portscan.PortACK) <-chan portscan.Po return out } -//MergeResultChannels as suggested -func MergeResultChannels(channels ...<-chan tlsmodel.ScanResult) <-chan tlsmodel.ScanResult { +//mergeResultChannels as suggested +func mergeResultChannels(channels ...<-chan tlsmodel.ScanResult) <-chan tlsmodel.ScanResult { var wg sync.WaitGroup out := make(chan tlsmodel.ScanResult) output := func(c <-chan tlsmodel.ScanResult) { diff --git a/pkg/tlsscan_test.go b/pkg/tlsscan_test.go index 7bf41d3..a8bf8b5 100644 --- a/pkg/tlsscan_test.go +++ b/pkg/tlsscan_test.go @@ -1,6 +1,5 @@ package tlsaudit -//TODO implement tests import ( "strings" "testing" @@ -9,35 +8,25 @@ import ( ) var ( - config = tlsmodel.ScanConfig{} + config = tlsmodel.ScanConfig{ + Timeout: 5, + } ) -// func TestIncompleteChain(t *testing.T) { -// for scan := range ScanCIDRTLS("incomplete-chain.badssl.com:443", config) { -// if len(scan.CertificatesWithChainIssue) == 0 { -// t.Errorf("Expected to find a chain issue") -// } -// } -// } +func TestIncompleteChain(t *testing.T) { + for _, scan := range ScanCIDRTLS("incomplete-chain.badssl.com:443", config) { + hs := scan.ToHumanScanResult() + if len(hs.CertificatesWithChainIssue) == 0 { + t.Errorf("Expected to find a chain issue %#v", hs) + } + } +} func TestRSA8192(t *testing.T) { - // results := []<-chan tlsmodel.ScanResult{} - // scans := make(map[string]tlsmodel.ScanResult) - - // results = append(results, ScanCIDRTLS("rsa8192.badssl.com:443", config)) - // for result := range MergeResultChannels(results...) { - // key := result.Server + result.Port - // if _, present := scans[key]; !present { - // scans[key] = result - // } - // } - - t.Logf("\nStarted scan\n") for _, scan := range ScanCIDRTLS("rsa8192.badssl.com:443", config) { - t.Log("Got a scan") for _, certChain := range scan.ToHumanScanResult().CertificatesPerProtocol { cert := certChain[0] - if cert.PublicKeyAlgorithm != "RSAs" { + if cert.PublicKeyAlgorithm != "RSA" { t.Errorf("Expecting an RSA public key algorithm but got %s", cert.PublicKeyAlgorithm) } kl := strings.Split(cert.Key, " ")[0] @@ -48,17 +37,17 @@ func TestRSA8192(t *testing.T) { } } -// func TestECDSA384(t *testing.T) { -// for scan := range ScanCIDRTLS("ecc384.badssl.com:443", config) { -// for _, certChain := range scan.ToHumanScanResult().CertificatesPerProtocol { -// cert := certChain[0] -// if cert.PublicKeyAlgorithm != "ECDSA" { -// t.Errorf("Expecting an ECDSA public key algorithm but got %s", cert.PublicKeyAlgorithm) -// } -// kl := strings.Split(cert.Key, " ")[1] -// if kl != "384" { -// t.Errorf("Expecting cert key length of 384, but got %s", kl) -// } -// } -// } -// } +func TestECDSA384(t *testing.T) { + for _, scan := range ScanCIDRTLS("ecc384.badssl.com:443", config) { + for _, certChain := range scan.ToHumanScanResult().CertificatesPerProtocol { + cert := certChain[0] + if cert.PublicKeyAlgorithm != "ECDSA" { + t.Errorf("Expecting an ECDSA public key algorithm but got %s", cert.PublicKeyAlgorithm) + } + kl := strings.Split(cert.Key, " ")[1] + if kl != "384" { + t.Errorf("Expecting cert key length of 384, but got %s", kl) + } + } + } +}