Skip to content

Commit

Permalink
Switch to netip package
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg committed Jun 14, 2022
1 parent 3002076 commit 0950cf2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 18 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/cristalhq/ipfilterware

go 1.17
go 1.18
41 changes: 26 additions & 15 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"strings"
"sync/atomic"
)
Expand Down Expand Up @@ -39,9 +40,14 @@ func New(next http.Handler, cfg *Config) (*Handler, error) {
// Wrap a given handler.
func (h *Handler) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := ipFromRequest(r)

filter := h.ipFilter.Load().(*ipFilter)

ip, err := ipFromRequest(r)
if err != nil {
filter.forbiddenHandler.ServeHTTP(w, r)
return
}

if filter.isAllowed(ip) {
next.ServeHTTP(w, r)
} else {
Expand All @@ -52,9 +58,14 @@ func (h *Handler) Wrap(next http.Handler) http.Handler {

// ServeHTTP implements http.Handler interface.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ip := ipFromRequest(r)

filter := h.ipFilter.Load().(*ipFilter)

ip, err := ipFromRequest(r)
if err != nil {
filter.forbiddenHandler.ServeHTTP(w, r)
return
}

if filter.isAllowed(ip) {
h.next.ServeHTTP(w, r)
} else {
Expand All @@ -73,7 +84,7 @@ func (h *Handler) Update(cfg *Config) error {
}

// IsAllowed reports whether given IP is allowed.
func (h *Handler) IsAllowed(ip net.IP) bool {
func (h *Handler) IsAllowed(ip netip.Addr) bool {
filter := h.ipFilter.Load().(*ipFilter)
return filter.isAllowed(ip)
}
Expand All @@ -87,7 +98,7 @@ var defaultForbiddenHandler = http.HandlerFunc(func(w http.ResponseWriter, r *ht
})

type ipFilter struct {
allowedIP map[string]struct{}
allowedIP map[netip.Addr]struct{}
allowedCIDR []*net.IPNet
forbiddenHandler http.Handler
}
Expand All @@ -112,31 +123,31 @@ func newIPFilter(cfg *Config) (*ipFilter, error) {
return ipf, nil
}

func (ipf *ipFilter) isAllowed(ip net.IP) bool {
if ip == nil {
func (ipf *ipFilter) isAllowed(ip netip.Addr) bool {
if !ip.IsValid() {
return false
}
if _, ok := ipf.allowedIP[ip.String()]; ok {
if _, ok := ipf.allowedIP[ip]; ok {
return true
}
for _, cidr := range ipf.allowedCIDR {
if cidr.Contains(ip) {
if cidr.Contains(net.ParseIP(ip.String())) {
return true
}
}
return false
}

func ipFromRequest(r *http.Request) net.IP {
func ipFromRequest(r *http.Request) (netip.Addr, error) {
ip := r.RemoteAddr
if strings.IndexByte(r.RemoteAddr, byte(':')) >= 0 {
ip, _, _ = net.SplitHostPort(r.RemoteAddr)
}
return net.ParseIP(ip)
return netip.ParseAddr(ip)
}

func parseIPWithCIDR(nets []string) (map[string]struct{}, []*net.IPNet, error) {
ips := make(map[string]struct{}, len(nets))
func parseIPWithCIDR(nets []string) (map[netip.Addr]struct{}, []*net.IPNet, error) {
ips := make(map[netip.Addr]struct{}, len(nets))
cidrs := make([]*net.IPNet, 0, len(nets))

for _, n := range nets {
Expand All @@ -145,7 +156,7 @@ func parseIPWithCIDR(nets []string) (map[string]struct{}, []*net.IPNet, error) {
continue
}
if ip := net.ParseIP(n); ip != nil {
ips[n] = struct{}{}
ips[netip.MustParseAddr(n)] = struct{}{}
continue
}
return nil, nil, fmt.Errorf("bad IP or CIDR: %q", n)
Expand Down
4 changes: 2 additions & 2 deletions handler_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package ipfilterware

import (
"net"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
)

Expand Down Expand Up @@ -52,7 +52,7 @@ func TestSingleIP(t *testing.T) {
{"10.20.0.1", false},
}
for i, tc := range testCases {
if f.IsAllowed(net.ParseIP(tc.ip)) != tc.isAllowed {
if f.IsAllowed(netip.MustParseAddr(tc.ip)) != tc.isAllowed {
t.Errorf("[%d] ip %q must be %v", i, tc.ip, tc.isAllowed)
}
}
Expand Down

0 comments on commit 0950cf2

Please sign in to comment.