From af7c59b5c2c094882d0aa947f792c277a7672e55 Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 6 Feb 2025 05:50:03 +0800 Subject: [PATCH] add tests for rules.on --- internal/route/rules/on_test.go | 154 ++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go index c5bdc8e8..fb7af458 100644 --- a/internal/route/rules/on_test.go +++ b/internal/route/rules/on_test.go @@ -1,10 +1,15 @@ package rules import ( + "encoding/base64" + "fmt" + "net/http" + "net/url" "testing" E "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" + "golang.org/x/crypto/bcrypt" ) func TestParseOn(t *testing.T) { @@ -122,3 +127,152 @@ func TestParseOn(t *testing.T) { }) } } + +type testCorrectness struct { + name string + checker string + input *http.Request + want bool +} + +func genCorrectnessTestCases(field string, genRequest func(k, v string) *http.Request) []testCorrectness { + return []testCorrectness{ + { + name: field + "_match", + checker: field + " foo bar", + input: genRequest("foo", "bar"), + want: true, + }, + { + name: field + "_no_match", + checker: field + " foo baz", + input: genRequest("foo", "bar"), + want: false, + }, + { + name: field + "_exists", + checker: field + " foo", + input: genRequest("foo", "abcd"), + want: true, + }, + { + name: field + "_not_exists", + checker: field + " foo", + input: genRequest("bar", "abcd"), + want: false, + }, + } +} + +func TestOnCorrectness(t *testing.T) { + tests := []testCorrectness{ + { + name: "method_match", + checker: "method GET", + input: &http.Request{Method: http.MethodGet}, + want: true, + }, + { + name: "method_no_match", + checker: "method GET", + input: &http.Request{Method: http.MethodPost}, + want: false, + }, + { + name: "path_exact_match", + checker: "path /example", + input: &http.Request{ + URL: &url.URL{Path: "/example"}, + }, + want: true, + }, + { + name: "path_wildcard_match", + checker: "path /example/*", + input: &http.Request{ + URL: &url.URL{Path: "/example/123"}, + }, + want: true, + }, + { + name: "remote_match", + checker: "remote 192.168.1.0/24", + input: &http.Request{ + RemoteAddr: "192.168.1.5", + }, + want: true, + }, + { + name: "remote_no_match", + checker: "remote 192.168.1.0/24", + input: &http.Request{ + RemoteAddr: "192.168.2.5", + }, + want: false, + }, + { + name: "basic_auth_correct", + checker: "basic_auth user " + string(E.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), + input: &http.Request{ + Header: http.Header{ + "Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password" + }, + }, + want: true, + }, + { + name: "basic_auth_incorrect", + checker: "basic_auth user " + string(E.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), + input: &http.Request{ + Header: http.Header{ + "Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong" + }, + }, + want: false, + }, + } + + tests = append(tests, genCorrectnessTestCases("header", func(k, v string) *http.Request { + return &http.Request{ + Header: http.Header{k: []string{v}}} + })...) + tests = append(tests, genCorrectnessTestCases("query", func(k, v string) *http.Request { + return &http.Request{ + URL: &url.URL{ + RawQuery: fmt.Sprintf("%s=%s", k, v), + }, + } + })...) + tests = append(tests, genCorrectnessTestCases("cookie", func(k, v string) *http.Request { + return &http.Request{ + Header: http.Header{ + "Cookie": {fmt.Sprintf("%s=%s", k, v)}, + }, + } + })...) + tests = append(tests, genCorrectnessTestCases("form", func(k, v string) *http.Request { + return &http.Request{ + Form: url.Values{ + k: []string{v}, + }, + } + })...) + tests = append(tests, genCorrectnessTestCases("postform", func(k, v string) *http.Request { + return &http.Request{ + PostForm: url.Values{ + k: []string{v}, + }, + } + })...) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + on, err := parseOn(tt.checker) + ExpectNoError(t, err) + got := on.Check(Cache{}, tt.input) + if tt.want != got { + t.Errorf("want %v, got %v", tt.want, got) + } + }) + } +}