Skip to content

Commit

Permalink
Enable JWT group-based user authorization (#1368)
Browse files Browse the repository at this point in the history
* Extend management API to support list of allowed JWT groups (#1366)

* Add JWTAllowGroups settings to account management

* Return an empty group list if jwt allow groups is not set

* Add JwtAllowGroups to account settings in handler test

* Add JWT group-based user authorization (#1373)

* Add JWTAllowGroups settings to account management

* Return an empty group list if jwt allow groups is not set

* Add JwtAllowGroups to account settings in handler test

* Implement user access validation authentication based on JWT groups

* Remove the slices package import due to compatibility issues with the gitHub workflow(s) Go version

* Refactor auth middleware and test for extracted claim handling

* Optimize JWT group check in auth middleware to cover nil and empty allowed groups
  • Loading branch information
bcmmbaga authored Dec 11, 2023
1 parent 5ecafef commit d275d41
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 10 deletions.
4 changes: 4 additions & 0 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ type Settings struct {
// JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string

// JWTAllowGroups list of groups to which users are allowed access
JWTAllowGroups []string `gorm:"serializer:json"`

// Extra is a dictionary of Account settings
Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
}
Expand All @@ -176,6 +179,7 @@ func (s *Settings) Copy() *Settings {
JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName,
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
JWTAllowGroups: s.JWTAllowGroups,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()
Expand Down
9 changes: 9 additions & 0 deletions management/server/http/accounts_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
if req.Settings.JwtGroupsClaimName != nil {
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
}
if req.Settings.JwtAllowGroups != nil {
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}

updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
if err != nil {
Expand Down Expand Up @@ -128,12 +131,18 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
}

func toAccountResponse(account *server.Account) *api.Account {
jwtAllowGroups := account.Settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
}

settings := api.AccountSettings{
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups,
}

if account.Settings.Extra != nil {
Expand Down
6 changes: 5 additions & 1 deletion management/server/http/accounts_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false),
JwtAllowGroups: &[]string{},
},
expectedArray: true,
expectedID: accountID,
Expand All @@ -112,6 +113,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false),
JwtAllowGroups: &[]string{},
},
expectedArray: false,
expectedID: accountID,
Expand All @@ -121,14 +123,15 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"),
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"]}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000,
PeerLoginExpirationEnabled: false,
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr("roles"),
JwtGroupsEnabled: br(true),
JwtAllowGroups: &[]string{"test"},
},
expectedArray: false,
expectedID: accountID,
Expand All @@ -146,6 +149,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(true),
JwtGroupsClaimName: sr("groups"),
JwtGroupsEnabled: br(true),
JwtAllowGroups: &[]string{},
},
expectedArray: false,
expectedID: accountID,
Expand Down
6 changes: 6 additions & 0 deletions management/server/http/api/openapi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ components:
description: Name of the claim from which we extract groups names to add it to account groups.
type: string
example: "roles"
jwt_allow_groups:
description: List of groups to which users are allowed access
type: array
items:
type: string
example: Administrators
extra:
$ref: '#/components/schemas/AccountExtraSettings'
required:
Expand Down
3 changes: 3 additions & 0 deletions management/server/http/api/types.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions management/server/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,20 @@ type emptyObject struct {

// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)

authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
accountManager.GetAccountFromToken,
claimsExtractor,
authCfg.Audience,
authCfg.UserIDClaim)
authCfg.UserIDClaim,
)

corsMiddleware := cors.AllowAll()

Expand All @@ -60,11 +68,6 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
AuthCfg: authCfg,
}

claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)

integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor)
api.addAccountsEndpoint()
api.addPeersEndpoint()
Expand Down
63 changes: 62 additions & 1 deletion management/server/http/middleware/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@ type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error

// GetAccountFromTokenFunc function
type GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)

// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
getAccountFromToken GetAccountFromTokenFunc
claimsExtractor *jwtclaims.ClaimsExtractor
audience string
userIDClaim string
}
Expand All @@ -40,14 +45,19 @@ const (
)

// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware {
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, getAccountFromToken GetAccountFromTokenFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
userIdClaim = jwtclaims.UserIDClaim
}

return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
getAccountFromToken: getAccountFromToken,
claimsExtractor: claimsExtractor,
audience: audience,
userIDClaim: userIdClaim,
}
Expand Down Expand Up @@ -107,6 +117,10 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}

if err := m.verifyUserAccess(validatedToken); err != nil {
return err
}

// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
Expand All @@ -115,6 +129,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}

// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
account, _, err := m.getAccountFromToken(authClaims)
if err != nil {
return fmt.Errorf("failed to get the account from token: %w", err)
}

// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := make([]string, 0)

if claim, ok := authClaims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
}
}
}
}

if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
}
}
}

return nil
}

// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth)
Expand Down Expand Up @@ -168,3 +217,15 @@ func getTokenFromPATRequest(authHeaderParts []string) (string, error) {

return authHeaderParts[1], nil
}

// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {
for _, allowedGroup := range allowedGroups {
if userGroup == allowedGroup {
return true
}
}
}
return false
}
37 changes: 35 additions & 2 deletions management/server/http/middleware/auth_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/golang-jwt/jwt"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)

const (
Expand Down Expand Up @@ -54,7 +55,13 @@ func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server

func mockValidateAndParseToken(token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{}, nil
return &jwt.Token{
Claims: jwt.MapClaims{
userIDClaim: userID,
audience + jwtclaims.AccountIDSuffix: accountID,
},
Valid: true,
}, nil
}
return nil, fmt.Errorf("JWT invalid")
}
Expand All @@ -66,6 +73,19 @@ func mockMarkPATUsed(token string) error {
return fmt.Errorf("Should never get reached")
}

func mockGetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
if testAccount.Id != claims.AccountId {
return nil, nil, fmt.Errorf("account with id %s does not exist", claims.AccountId)
}

user, ok := testAccount.Users[claims.UserId]
if !ok {
return nil, nil, fmt.Errorf("user with id %s does not exist", claims.UserId)
}

return testAccount, user, nil
}

func TestAuthMiddleware_Handler(t *testing.T) {
tt := []struct {
name string
Expand Down Expand Up @@ -108,7 +128,20 @@ func TestAuthMiddleware_Handler(t *testing.T) {
// do nothing
})

authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim)
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)

authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockGetAccountFromToken,
claimsExtractor,
audience,
userIDClaim,
)

handlerToTest := authMiddleware.Handler(nextHandler)

Expand Down

0 comments on commit d275d41

Please sign in to comment.