Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change auth mechanism for websocket #25

Merged
merged 3 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions internal/http_server/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func handleRefreshRequest(w http.ResponseWriter, r *http.Request) {
}
}

func auth(next http.Handler) http.Handler {
func bearerAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authorization := r.Header.Get("Authorization")

Expand All @@ -201,32 +201,56 @@ func auth(next http.Handler) http.Handler {

token := strings.Split(authorization, "Bearer ")[1]

jwtToken, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
return []byte(viper.GetString("http.auth.jwt.accessTokenSecret")), nil
}, jwt.WithValidMethods([]string{"HS256"}), jwt.WithIssuer("excubitor-backend"))
if checkToken(w, r, token) {
next.ServeHTTP(w, r)
}

if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
logger.Debug(fmt.Sprintf("Attempt to authenticate with expired token from %s!", r.RemoteAddr))
ReturnError(w, r, http.StatusUnauthorized, "Token expired!")
return
} else if errors.Is(err, jwt.ErrSignatureInvalid) {
logger.Warn(fmt.Sprintf("Attempt to authenticate with invalid signature from %s!", r.RemoteAddr))
} else {
logger.Debug(fmt.Sprintf("Attempt to authenticate with invalid token from %s! Reason: %s", r.RemoteAddr, err))
}
})
}

func queryAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")

ReturnError(w, r, http.StatusUnauthorized, "Invalid token!")
if token == "" {
logger.Debug(fmt.Sprintf("Attempt to authenticate with invalid token format from %s!", r.RemoteAddr))
ReturnError(w, r, http.StatusBadRequest, "Token of invalid format!")
return
}

user, err := jwtToken.Claims.GetSubject()
if err != nil {
logger.Warn(fmt.Sprintf("Couldn't read token subject from %s!", user))
if checkToken(w, r, token) {
next.ServeHTTP(w, r)
}

logger.Trace(fmt.Sprintf("User %s authenticated successfully using JWT token!", user))

next.ServeHTTP(w, r)
})
}

func checkToken(w http.ResponseWriter, r *http.Request, token string) bool {
jwtToken, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
return []byte(viper.GetString("http.auth.jwt.accessTokenSecret")), nil
}, jwt.WithValidMethods([]string{"HS256"}), jwt.WithIssuer("excubitor-backend"))

if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
logger.Debug(fmt.Sprintf("Attempt to authenticate with expired token from %s!", r.RemoteAddr))
ReturnError(w, r, http.StatusUnauthorized, "Token expired!")
return false
} else if errors.Is(err, jwt.ErrSignatureInvalid) {
logger.Warn(fmt.Sprintf("Attempt to authenticate with invalid signature from %s!", r.RemoteAddr))
} else {
logger.Debug(fmt.Sprintf("Attempt to authenticate with invalid token from %s! Reason: %s", r.RemoteAddr, err))
}

ReturnError(w, r, http.StatusUnauthorized, "Invalid token!")
return false
}

user, err := jwtToken.Claims.GetSubject()
if err != nil {
logger.Warn(fmt.Sprintf("Couldn't read token subject from %s!", user))
}

logger.Trace(fmt.Sprintf("User %s authenticated successfully using JWT token!", user))

return true
}
208 changes: 197 additions & 11 deletions internal/http_server/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ func TestHandleRefreshRequest(t *testing.T) {
assert.Equal(t, issuer, "excubitor-backend")
}

func TestAuthNoHeader(t *testing.T) {
func TestBearerAuthNoHeader(t *testing.T) {
var err error

logger, err = logging.GetConsoleLoggerInstance()
Expand All @@ -467,7 +467,7 @@ func TestAuthNoHeader(t *testing.T) {
req.RemoteAddr = "SampleAddress"
w := httptest.NewRecorder()

handler := auth(nil)
handler := bearerAuth(nil)
handler.ServeHTTP(w, req)

res := w.Result()
Expand All @@ -493,7 +493,7 @@ func TestAuthNoHeader(t *testing.T) {
assert.True(t, time.Since(httpError.Timestamp) < time.Since(time.Now().Add(-time.Second)) && time.Until(httpError.Timestamp) < 0)
}

func TestAuthInvalidHeader(t *testing.T) {
func TestBearerAuthInvalidHeader(t *testing.T) {
var err error

logger, err = logging.GetConsoleLoggerInstance()
Expand All @@ -506,7 +506,7 @@ func TestAuthInvalidHeader(t *testing.T) {
req.Header.Set("Authorization", "Basic dXNlcm5hbWU6cGFzc3dvcmQ=")
w := httptest.NewRecorder()

handler := auth(nil)
handler := bearerAuth(nil)
handler.ServeHTTP(w, req)

res := w.Result()
Expand All @@ -532,7 +532,7 @@ func TestAuthInvalidHeader(t *testing.T) {
assert.True(t, time.Since(httpError.Timestamp) < time.Since(time.Now().Add(-time.Second)) && time.Until(httpError.Timestamp) < 0)
}

func TestAuthTokenExpired(t *testing.T) {
func TestBearerAuthTokenExpired(t *testing.T) {
var err error

logger, err = logging.GetConsoleLoggerInstance()
Expand All @@ -559,7 +559,7 @@ func TestAuthTokenExpired(t *testing.T) {
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()

handler := auth(nil)
handler := bearerAuth(nil)
handler.ServeHTTP(w, req)

res := w.Result()
Expand All @@ -584,7 +584,7 @@ func TestAuthTokenExpired(t *testing.T) {
assert.True(t, time.Since(httpError.Timestamp) < time.Since(time.Now().Add(-time.Second)) && time.Until(httpError.Timestamp) < 0)
}

func TestAuthInvalidSignature(t *testing.T) {
func TestBearerAuthInvalidSignature(t *testing.T) {
var err error

logger, err = logging.GetConsoleLoggerInstance()
Expand Down Expand Up @@ -612,7 +612,7 @@ func TestAuthInvalidSignature(t *testing.T) {
req.Header.Set("Authorization", "Bearer "+signedToken)
w := httptest.NewRecorder()

handler := auth(nil)
handler := bearerAuth(nil)
handler.ServeHTTP(w, req)

res := w.Result()
Expand All @@ -637,7 +637,7 @@ func TestAuthInvalidSignature(t *testing.T) {
assert.True(t, time.Since(httpError.Timestamp) < time.Since(time.Now().Add(-time.Second)) && time.Until(httpError.Timestamp) < 0)
}

func TestAuth(t *testing.T) {
func TestBearerAuth(t *testing.T) {
var err error

logger, err = logging.GetConsoleLoggerInstance()
Expand All @@ -659,12 +659,198 @@ func TestAuth(t *testing.T) {
return
}

req := httptest.NewRequest(http.MethodPost, "/auth/refresh", nil)
req := httptest.NewRequest(http.MethodPost, "/someEndpoint", nil)
req.RemoteAddr = "SampleAddress"
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()

handler := auth(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
handler := bearerAuth(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
_, err := writer.Write([]byte{})
if err != nil {
t.Error(err)
return
}
}))
handler.ServeHTTP(w, req)

res := w.Result()
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
t.Error(err)
}
}(res.Body)

assert.Equal(t, http.StatusOK, res.StatusCode)
}

func TestQueryAuthNoToken(t *testing.T) {
var err error

logger, err = logging.GetConsoleLoggerInstance()
if err != nil {
t.Error(err)
}

req := httptest.NewRequest(http.MethodPost, "/someEndpoint?token=", nil)
req.RemoteAddr = "SampleAddress"
w := httptest.NewRecorder()

handler := queryAuth(nil)
handler.ServeHTTP(w, req)

res := w.Result()
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
t.Error(err)
}
}(res.Body)

body, err := io.ReadAll(res.Body)
if err != nil {
t.Error(err)
return
}

httpError := parseHTTPError(body)

assert.Equal(t, http.StatusBadRequest, res.StatusCode)
assert.Equal(t, "Token of invalid format!", httpError.Message)
assert.Equal(t, "/someEndpoint?token=", httpError.Path)
assert.True(t, time.Since(httpError.Timestamp) < time.Since(time.Now().Add(-time.Second)) && time.Until(httpError.Timestamp) < 0)
}

func TestQueryAuthTokenExpired(t *testing.T) {
var err error

logger, err = logging.GetConsoleLoggerInstance()
if err != nil {
t.Error(err)
return
}

viper.SetDefault("http.auth.jwt.accessTokenSecret", "123456")
viper.SetDefault("http.auth.jwt.refreshTokenSecret", "abcdef")

token, err := signAccessToken(jwt.MapClaims{
"iss": "excubitor-backend",
"sub": "testuser",
"exp": time.Now().Add(-30 * time.Minute).Unix(),
})
if err != nil {
t.Error(err)
return
}

req := httptest.NewRequest(http.MethodPost, "/someEndpoint?token="+token, nil)
req.RemoteAddr = "SampleAddress"
w := httptest.NewRecorder()

handler := queryAuth(nil)
handler.ServeHTTP(w, req)

res := w.Result()
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
t.Error(err)
}
}(res.Body)

body, err := io.ReadAll(res.Body)
if err != nil {
t.Error(err)
return
}

httpError := parseHTTPError(body)

assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
assert.Equal(t, "Token expired!", httpError.Message)
assert.Equal(t, "/someEndpoint?token="+token, httpError.Path)
assert.True(t, time.Since(httpError.Timestamp) < time.Since(time.Now().Add(-time.Second)) && time.Until(httpError.Timestamp) < 0)
}

func TestQueryAuthInvalidSignature(t *testing.T) {
var err error

logger, err = logging.GetConsoleLoggerInstance()
if err != nil {
t.Error(err)
return
}

viper.SetDefault("http.auth.jwt.accessTokenSecret", "123456")
viper.SetDefault("http.auth.jwt.refreshTokenSecret", "abcdef")

token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": "excubitor-backend",
"sub": "testuser",
"exp": time.Now().Add(-4 * time.Hour).Unix(),
})
signedToken, err := token.SignedString([]byte("someOtherKey"))
if err != nil {
t.Error(err)
return
}

req := httptest.NewRequest(http.MethodPost, "/someEndpoint?token="+signedToken, nil)
req.RemoteAddr = "SampleAddress"
w := httptest.NewRecorder()

handler := queryAuth(nil)
handler.ServeHTTP(w, req)

res := w.Result()
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
t.Error(err)
}
}(res.Body)

body, err := io.ReadAll(res.Body)
if err != nil {
t.Error(err)
return
}

httpError := parseHTTPError(body)

assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
assert.Equal(t, "Invalid token!", httpError.Message)
assert.Equal(t, "/someEndpoint?token="+signedToken, httpError.Path)
assert.True(t, time.Since(httpError.Timestamp) < time.Since(time.Now().Add(-time.Second)) && time.Until(httpError.Timestamp) < 0)
}

func TestQueryAuth(t *testing.T) {
var err error

logger, err = logging.GetConsoleLoggerInstance()
if err != nil {
t.Error(err)
return
}

viper.SetDefault("http.auth.jwt.accessTokenSecret", "123456")
viper.SetDefault("http.auth.jwt.refreshTokenSecret", "abcdef")

token, err := signAccessToken(jwt.MapClaims{
"iss": "excubitor-backend",
"sub": "testuser",
"exp": time.Now().Add(30 * time.Minute).Unix(),
})
if err != nil {
t.Error(err)
return
}

req := httptest.NewRequest(http.MethodPost, "/someEndpoint?token="+token, nil)
req.RemoteAddr = "SampleAddress"
w := httptest.NewRecorder()

handler := queryAuth(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
_, err := writer.Write([]byte{})
if err != nil {
t.Error(err)
Expand Down
2 changes: 1 addition & 1 deletion internal/http_server/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func Start() error {
mux.HandleFunc("/info", info)
mux.HandleFunc("/auth", handleAuthRequest)
mux.HandleFunc("/auth/refresh", handleRefreshRequest)
mux.Handle("/ws", auth(http.HandlerFunc(wsInit)))
mux.Handle("/ws", queryAuth(http.HandlerFunc(wsInit)))

cors := getCORSHandler()

Expand Down