Skip to content

Commit

Permalink
Adding unit tests for the callback handler (flyteorg#189)
Browse files Browse the repository at this point in the history
* Adding unit tests for the callback handler

Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com>

* Fixed linter issues

Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com>
  • Loading branch information
pmahindrakar-oss authored May 5, 2021
1 parent 8235d1f commit db32bb3
Showing 1 changed file with 169 additions and 6 deletions.
175 changes: 169 additions & 6 deletions auth/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,192 @@ package auth

import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"

"github.com/stretchr/testify/mock"

stdConfig "github.com/flyteorg/flytestdlib/config"
"testing"

"github.com/flyteorg/flyteadmin/auth/config"
"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
"github.com/flyteorg/flyteadmin/pkg/common"
stdConfig "github.com/flyteorg/flytestdlib/config"

"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
"github.com/coreos/go-oidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
)

"testing"
const (
oauth2TokenURL = "/oauth2/token" // #nosec
)

func TestWithUserEmail(t *testing.T) {
ctx := WithUserEmail(context.Background(), "abc")
assert.Equal(t, "abc", ctx.Value(common.PrincipalContextKey))
}

func setupMockedAuthContextAtEndpoint(endpoint string) *mocks.AuthenticationContext {
mockAuthCtx := &mocks.AuthenticationContext{}
mockAuthCtx.OnOptions().Return(&config.Config{})
mockCookieHandler := new(mocks.CookieHandler)
dummyOAuth2Config := oauth2.Config{
ClientID: "abc",
Endpoint: oauth2.Endpoint{
AuthURL: endpoint + "/oauth2/authorize",
TokenURL: endpoint + oauth2TokenURL,
},
Scopes: []string{"openid", "other"},
}
mockAuthCtx.OnCookieManagerMatch().Return(mockCookieHandler)
mockCookieHandler.OnSetTokenCookiesMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockCookieHandler.OnSetUserInfoCookieMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockAuthCtx.OnOAuth2ClientConfigMatch(mock.Anything).Return(&dummyOAuth2Config)
return mockAuthCtx
}

func addStateString(request *http.Request) {
v := url.Values{
"state": []string{"b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"},
}
request.Form = v
}

func addCsrfCookie(request *http.Request) {
cookie := NewCsrfCookie()
cookie.Value = "hello world"
request.AddCookie(&cookie)
}

func TestGetCallbackHandlerWithErrorOnToken(t *testing.T) {
ctx := context.Background()
hf := func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == oauth2TokenURL {
w.WriteHeader(403)
return
}
}
localServer := httptest.NewServer(http.HandlerFunc(hf))
defer localServer.Close()
http.DefaultClient = localServer.Client()
mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx)
request := httptest.NewRequest("GET", localServer.URL+"/callback", nil)
addCsrfCookie(request)
addStateString(request)
writer := httptest.NewRecorder()
callbackHandlerFunc(writer, request)
assert.Equal(t, "403 Forbidden", writer.Result().Status)
}

func TestGetCallbackHandlerWithUnAuthorized(t *testing.T) {
ctx := context.Background()
hf := func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == oauth2TokenURL {
w.WriteHeader(403)
return
}
}
localServer := httptest.NewServer(http.HandlerFunc(hf))
defer localServer.Close()
http.DefaultClient = localServer.Client()
mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx)
request := httptest.NewRequest("GET", localServer.URL+"/callback", nil)
writer := httptest.NewRecorder()
callbackHandlerFunc(writer, request)
assert.Equal(t, "401 Unauthorized", writer.Result().Status)
}

func TestGetCallbackHandler(t *testing.T) {
var openIDConfigJSON string
var userInfoJSON string
ctx := context.Background()
hf := func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == oauth2TokenURL {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"Sample.Access.Token",
"issued_token_type":"urn:ietf:params:oauth:token-type:access_token",
"token_type":"Bearer",
"expires_in":3600,
"scope":"all"}`)
return
}
if r.URL.Path == "/.well-known/openid-configuration" {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, openIDConfigJSON)
return
}
if r.URL.Path == "/userinfo" {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, userInfoJSON)
return
}
}
localServer := httptest.NewServer(http.HandlerFunc(hf))
defer localServer.Close()
http.DefaultClient = localServer.Client()
issuer := localServer.URL
userInfoJSON = `{
"subject" : "dummySubject",
"profile" : "dummyProfile",
"email" : "dummyEmail"
}`
openIDConfigJSON = fmt.Sprintf(`{
"issuer": "%v",
"authorization_endpoint": "%v/auth",
"token_endpoint": "%v/token",
"jwks_uri": "%v/keys",
"id_token_signing_alg_values_supported": ["RS256"]
}`, issuer, issuer, issuer, issuer)

t.Run("forbidden request when accessing user info", func(t *testing.T) {
mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx)
request := httptest.NewRequest("GET", localServer.URL+"/callback", nil)
addCsrfCookie(request)
addStateString(request)
writer := httptest.NewRecorder()
openIDConfigJSON = fmt.Sprintf(`{
"issuer": "%v",
"authorization_endpoint": "%v/auth",
"token_endpoint": "%v/token",
"jwks_uri": "%v/keys",
"id_token_signing_alg_values_supported": ["RS256"]
}`, issuer, issuer, issuer, issuer)
oidcProvider, err := oidc.NewProvider(ctx, issuer)
assert.Nil(t, err)
mockAuthCtx.OnOidcProviderMatch().Return(oidcProvider)
callbackHandlerFunc(writer, request)
assert.Equal(t, "403 Forbidden", writer.Result().Status)
})

t.Run("successful callback and redirect", func(t *testing.T) {
mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx)
request := httptest.NewRequest("GET", localServer.URL+"/callback", nil)
addCsrfCookie(request)
addStateString(request)
writer := httptest.NewRecorder()
openIDConfigJSON = fmt.Sprintf(`{
"userinfo_endpoint": "%v/userinfo",
"issuer": "%v",
"authorization_endpoint": "%v/auth",
"token_endpoint": "%v/token",
"jwks_uri": "%v/keys",
"id_token_signing_alg_values_supported": ["RS256"]
}`, issuer, issuer, issuer, issuer, issuer)
oidcProvider, err := oidc.NewProvider(ctx, issuer)
assert.Nil(t, err)
mockAuthCtx.OnOidcProviderMatch().Return(oidcProvider)
callbackHandlerFunc(writer, request)
assert.Equal(t, "307 Temporary Redirect", writer.Result().Status)
})
}

func TestGetLoginHandler(t *testing.T) {
ctx := context.Background()
dummyOAuth2Config := oauth2.Config{
Expand Down

0 comments on commit db32bb3

Please sign in to comment.