Skip to content

Commit

Permalink
feat(): add CF-Connecting-IP (#908)
Browse files Browse the repository at this point in the history
* feat(): add CF-Connection-IP

* feat(): custom real ip header

#908

* fix(): typo

* refactor(): remove cf-connecting-ip from default headers

* refactor(): back to unexported default headers
  • Loading branch information
n33pm authored Sep 18, 2024
1 parent 134f373 commit cbaac31
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 19 deletions.
51 changes: 32 additions & 19 deletions middleware/realip.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
"strings"
)

var trueClientIP = http.CanonicalHeaderKey("True-Client-IP")
var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
var xRealIP = http.CanonicalHeaderKey("X-Real-IP")
var defaultHeaders = []string{
"True-Client-IP", // Cloudflare Enterprise plan
"X-Real-IP",
"X-Forwarded-For",
}

// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers
Expand All @@ -30,7 +32,7 @@ var xRealIP = http.CanonicalHeaderKey("X-Real-IP")
// how you're using RemoteAddr, vulnerable to an attack of some sort).
func RealIP(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if rip := realIP(r); rip != "" {
if rip := getRealIP(r, defaultHeaders); rip != "" {
r.RemoteAddr = rip
}
h.ServeHTTP(w, r)
Expand All @@ -39,22 +41,33 @@ func RealIP(h http.Handler) http.Handler {
return http.HandlerFunc(fn)
}

func realIP(r *http.Request) string {
var ip string

if tcip := r.Header.Get(trueClientIP); tcip != "" {
ip = tcip
} else if xrip := r.Header.Get(xRealIP); xrip != "" {
ip = xrip
} else if xff := r.Header.Get(xForwardedFor); xff != "" {
i := strings.Index(xff, ",")
if i == -1 {
i = len(xff)
// RealIPFromHeaders is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing the custom headers.
//
// usage:
// r.Use(RealIPFromHeaders("CF-Connecting-IP"))
func RealIPFromHeaders(headers ...string) func(http.Handler) http.Handler {
f := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if rip := getRealIP(r, headers); rip != "" {
r.RemoteAddr = rip
}
h.ServeHTTP(w, r)
}
ip = xff[:i]
return http.HandlerFunc(fn)
}
if ip == "" || net.ParseIP(ip) == nil {
return ""
return f
}

func getRealIP(r *http.Request, headers []string) string {
for _, header := range headers {
if ip := r.Header.Get(header); ip != "" {
ips := strings.Split(ip, ",")
if ips[0] == "" || net.ParseIP(ips[0]) == nil {
continue
}
return ips[0]
}
}
return ip
return ""
}
49 changes: 49 additions & 0 deletions middleware/realip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,52 @@ func TestInvalidIP(t *testing.T) {
t.Fatal("Invalid IP used.")
}
}

func TestCustomIPHeader(t *testing.T) {
var customHeaderKey = "X-CUSTOM-IP"
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Add(customHeaderKey, "100.100.100.100")
w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(RealIPFromHeaders(customHeaderKey))

realIP := ""
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
realIP = r.RemoteAddr
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Fatal("Response Code should be 200")
}

if realIP != "100.100.100.100" {
t.Fatal("Test get real IP precedence error.")
}
}

func TestCustomIPHeaderWithoutDefault(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Add("X-REAL-IP", "100.100.100.100")
w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(RealIPFromHeaders("CF-Connecting-IP"))

realIP := ""
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
realIP = r.RemoteAddr
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Fatal("Response Code should be 200")
}

if realIP != "" {
t.Fatal("Invalid IP used.")
}
}

0 comments on commit cbaac31

Please sign in to comment.