diff --git a/pkg/output/stats/stats.go b/pkg/output/stats/stats.go index 1ec99816d7..1e030a88da 100644 --- a/pkg/output/stats/stats.go +++ b/pkg/output/stats/stats.go @@ -6,14 +6,13 @@ package stats import ( _ "embed" - "encoding/json" "fmt" - "regexp" "sort" "strconv" "sync/atomic" "github.com/logrusorgru/aurora" + "github.com/projectdiscovery/nuclei/v3/pkg/output/stats/waf" mapsutil "github.com/projectdiscovery/utils/maps" ) @@ -25,7 +24,7 @@ type Tracker struct { wafDetected *mapsutil.SyncLockMap[string, *atomic.Int32] // internal stuff - wafDetector *wafDetector + wafDetector *waf.WafDetector } // NewTracker creates a new Tracker instance. @@ -34,7 +33,7 @@ func NewTracker() *Tracker { statusCodes: mapsutil.NewSyncLockMap[string, *atomic.Int32](), errorCodes: mapsutil.NewSyncLockMap[string, *atomic.Int32](), wafDetected: mapsutil.NewSyncLockMap[string, *atomic.Int32](), - wafDetector: newWafDetector(), + wafDetector: waf.NewWafDetector(), } } @@ -92,7 +91,7 @@ func (t *Tracker) GetStats() *StatsOutput { return nil }) _ = t.wafDetected.Iterate(func(k string, v *atomic.Int32) error { - waf, ok := t.wafDetector.wafs[k] + waf, ok := t.wafDetector.GetWAF(k) if !ok { return nil } @@ -102,55 +101,6 @@ func (t *Tracker) GetStats() *StatsOutput { return stats } -type wafDetector struct { - wafs map[string]waf - regexCache map[string]*regexp.Regexp -} - -// waf represents a web application firewall definition -type waf struct { - Company string `json:"company"` - Name string `json:"name"` - Regex string `json:"regex"` -} - -// wafData represents the root JSON structure -type wafData struct { - WAFs map[string]waf `json:"wafs"` -} - -//go:embed regexes.json -var wafContentRegexes string - -func newWafDetector() *wafDetector { - var data wafData - if err := json.Unmarshal([]byte(wafContentRegexes), &data); err != nil { - panic("could not unmarshal waf content regexes: " + err.Error()) - } - - store := &wafDetector{ - wafs: data.WAFs, - regexCache: make(map[string]*regexp.Regexp), - } - - for id, waf := range store.wafs { - if waf.Regex == "" { - continue - } - store.regexCache[id] = regexp.MustCompile(waf.Regex) - } - return store -} - -func (d *wafDetector) DetectWAF(content string) (string, bool) { - for id, regex := range d.regexCache { - if regex.MatchString(content) { - return id, true - } - } - return "", false -} - // DisplayTopStats prints the most relevant statistics for CLI func (t *Tracker) DisplayTopStats(noColor bool) { stats := t.GetStats() diff --git a/pkg/output/stats/stats_test.go b/pkg/output/stats/stats_test.go index 78ce398c0a..2eec59f300 100644 --- a/pkg/output/stats/stats_test.go +++ b/pkg/output/stats/stats_test.go @@ -26,59 +26,11 @@ func TestTrackErrorKind(t *testing.T) { } } -func TestWAFDetection(t *testing.T) { - detector := newWafDetector() - if detector == nil { - t.Fatal("expected non-nil wafDetector") - } - - tests := []struct { - name string - content string - expectedWAF string - shouldMatch bool - }{ - { - name: "Cloudflare WAF", - content: "Attention Required! | Cloudflare", - expectedWAF: "cloudflare", - shouldMatch: true, - }, - { - name: "ModSecurity WAF", - content: "This error was generated by Mod_Security", - expectedWAF: "modsecurity", - shouldMatch: true, - }, - { - name: "No WAF", - content: "Regular response with no WAF signature", - expectedWAF: "", - shouldMatch: false, - }, - { - name: "Wordfence WAF", - content: "Generated by Wordfence", - expectedWAF: "wordfence", - shouldMatch: true, - }, - { - name: "Sucuri WAF", - content: "Access Denied - Sucuri Website Firewall", - expectedWAF: "sucuri", - shouldMatch: true, - }, - } +func TestTrackWaf_Detect(t *testing.T) { + tracker := NewTracker() - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - waf, matched := detector.DetectWAF(tt.content) - if matched != tt.shouldMatch { - t.Errorf("expected match=%v, got match=%v", tt.shouldMatch, matched) - } - if matched && waf != tt.expectedWAF { - t.Errorf("expected WAF=%s, got WAF=%s", tt.expectedWAF, waf) - } - }) + tracker.TrackWAFDetected("Attention Required! | Cloudflare") + if count, _ := tracker.wafDetected.Get("cloudflare"); count == nil || count.Load() != 1 { + t.Errorf("expected waf detected count to be 1, got %v", count) } } diff --git a/pkg/output/stats/regexes.json b/pkg/output/stats/waf/regexes.json similarity index 100% rename from pkg/output/stats/regexes.json rename to pkg/output/stats/waf/regexes.json diff --git a/pkg/output/stats/waf/waf.go b/pkg/output/stats/waf/waf.go new file mode 100644 index 0000000000..c454588461 --- /dev/null +++ b/pkg/output/stats/waf/waf.go @@ -0,0 +1,62 @@ +package waf + +import ( + _ "embed" + "encoding/json" + "log" + "regexp" +) + +type WafDetector struct { + wafs map[string]waf + regexCache map[string]*regexp.Regexp +} + +// waf represents a web application firewall definition +type waf struct { + Company string `json:"company"` + Name string `json:"name"` + Regex string `json:"regex"` +} + +// wafData represents the root JSON structure +type wafData struct { + WAFs map[string]waf `json:"wafs"` +} + +//go:embed regexes.json +var wafContentRegexes string + +func NewWafDetector() *WafDetector { + var data wafData + if err := json.Unmarshal([]byte(wafContentRegexes), &data); err != nil { + log.Printf("could not unmarshal waf content regexes: %s", err) + } + + store := &WafDetector{ + wafs: data.WAFs, + regexCache: make(map[string]*regexp.Regexp), + } + + for id, waf := range store.wafs { + if waf.Regex == "" { + continue + } + store.regexCache[id] = regexp.MustCompile(waf.Regex) + } + return store +} + +func (d *WafDetector) DetectWAF(content string) (string, bool) { + for id, regex := range d.regexCache { + if regex.MatchString(content) { + return id, true + } + } + return "", false +} + +func (d *WafDetector) GetWAF(id string) (waf, bool) { + waf, ok := d.wafs[id] + return waf, ok +} diff --git a/pkg/output/stats/waf/waf_test.go b/pkg/output/stats/waf/waf_test.go new file mode 100644 index 0000000000..0698b3a42c --- /dev/null +++ b/pkg/output/stats/waf/waf_test.go @@ -0,0 +1,60 @@ +package waf + +import "testing" + +func TestWAFDetection(t *testing.T) { + detector := NewWafDetector() + if detector == nil { + t.Fatal("expected non-nil wafDetector") + } + + tests := []struct { + name string + content string + expectedWAF string + shouldMatch bool + }{ + { + name: "Cloudflare WAF", + content: "Attention Required! | Cloudflare", + expectedWAF: "cloudflare", + shouldMatch: true, + }, + { + name: "ModSecurity WAF", + content: "This error was generated by Mod_Security", + expectedWAF: "modsecurity", + shouldMatch: true, + }, + { + name: "No WAF", + content: "Regular response with no WAF signature", + expectedWAF: "", + shouldMatch: false, + }, + { + name: "Wordfence WAF", + content: "Generated by Wordfence", + expectedWAF: "wordfence", + shouldMatch: true, + }, + { + name: "Sucuri WAF", + content: "Access Denied - Sucuri Website Firewall", + expectedWAF: "sucuri", + shouldMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + waf, matched := detector.DetectWAF(tt.content) + if matched != tt.shouldMatch { + t.Errorf("expected match=%v, got match=%v", tt.shouldMatch, matched) + } + if matched && waf != tt.expectedWAF { + t.Errorf("expected WAF=%s, got WAF=%s", tt.expectedWAF, waf) + } + }) + } +}