From 0950cf204124ee9b32f54bd8f232c57536102a59 Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Wed, 15 Jun 2022 00:12:57 +0200 Subject: [PATCH] Switch to netip package --- go.mod | 2 +- handler.go | 41 ++++++++++++++++++++++++++--------------- handler_test.go | 4 ++-- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index 2ef2124..d795dfa 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/cristalhq/ipfilterware -go 1.17 +go 1.18 diff --git a/handler.go b/handler.go index 5318ce2..8b1ea9c 100644 --- a/handler.go +++ b/handler.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "net/netip" "strings" "sync/atomic" ) @@ -39,9 +40,14 @@ func New(next http.Handler, cfg *Config) (*Handler, error) { // Wrap a given handler. func (h *Handler) Wrap(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip := ipFromRequest(r) - filter := h.ipFilter.Load().(*ipFilter) + + ip, err := ipFromRequest(r) + if err != nil { + filter.forbiddenHandler.ServeHTTP(w, r) + return + } + if filter.isAllowed(ip) { next.ServeHTTP(w, r) } else { @@ -52,9 +58,14 @@ func (h *Handler) Wrap(next http.Handler) http.Handler { // ServeHTTP implements http.Handler interface. func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ip := ipFromRequest(r) - filter := h.ipFilter.Load().(*ipFilter) + + ip, err := ipFromRequest(r) + if err != nil { + filter.forbiddenHandler.ServeHTTP(w, r) + return + } + if filter.isAllowed(ip) { h.next.ServeHTTP(w, r) } else { @@ -73,7 +84,7 @@ func (h *Handler) Update(cfg *Config) error { } // IsAllowed reports whether given IP is allowed. -func (h *Handler) IsAllowed(ip net.IP) bool { +func (h *Handler) IsAllowed(ip netip.Addr) bool { filter := h.ipFilter.Load().(*ipFilter) return filter.isAllowed(ip) } @@ -87,7 +98,7 @@ var defaultForbiddenHandler = http.HandlerFunc(func(w http.ResponseWriter, r *ht }) type ipFilter struct { - allowedIP map[string]struct{} + allowedIP map[netip.Addr]struct{} allowedCIDR []*net.IPNet forbiddenHandler http.Handler } @@ -112,31 +123,31 @@ func newIPFilter(cfg *Config) (*ipFilter, error) { return ipf, nil } -func (ipf *ipFilter) isAllowed(ip net.IP) bool { - if ip == nil { +func (ipf *ipFilter) isAllowed(ip netip.Addr) bool { + if !ip.IsValid() { return false } - if _, ok := ipf.allowedIP[ip.String()]; ok { + if _, ok := ipf.allowedIP[ip]; ok { return true } for _, cidr := range ipf.allowedCIDR { - if cidr.Contains(ip) { + if cidr.Contains(net.ParseIP(ip.String())) { return true } } return false } -func ipFromRequest(r *http.Request) net.IP { +func ipFromRequest(r *http.Request) (netip.Addr, error) { ip := r.RemoteAddr if strings.IndexByte(r.RemoteAddr, byte(':')) >= 0 { ip, _, _ = net.SplitHostPort(r.RemoteAddr) } - return net.ParseIP(ip) + return netip.ParseAddr(ip) } -func parseIPWithCIDR(nets []string) (map[string]struct{}, []*net.IPNet, error) { - ips := make(map[string]struct{}, len(nets)) +func parseIPWithCIDR(nets []string) (map[netip.Addr]struct{}, []*net.IPNet, error) { + ips := make(map[netip.Addr]struct{}, len(nets)) cidrs := make([]*net.IPNet, 0, len(nets)) for _, n := range nets { @@ -145,7 +156,7 @@ func parseIPWithCIDR(nets []string) (map[string]struct{}, []*net.IPNet, error) { continue } if ip := net.ParseIP(n); ip != nil { - ips[n] = struct{}{} + ips[netip.MustParseAddr(n)] = struct{}{} continue } return nil, nil, fmt.Errorf("bad IP or CIDR: %q", n) diff --git a/handler_test.go b/handler_test.go index b564ca0..327dd1e 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,9 +1,9 @@ package ipfilterware import ( - "net" "net/http" "net/http/httptest" + "net/netip" "testing" ) @@ -52,7 +52,7 @@ func TestSingleIP(t *testing.T) { {"10.20.0.1", false}, } for i, tc := range testCases { - if f.IsAllowed(net.ParseIP(tc.ip)) != tc.isAllowed { + if f.IsAllowed(netip.MustParseAddr(tc.ip)) != tc.isAllowed { t.Errorf("[%d] ip %q must be %v", i, tc.ip, tc.isAllowed) } }