Skip to content

Commit

Permalink
BED-4945: Update PATCH function for SSO (#929)
Browse files Browse the repository at this point in the history
* wip

* Altered UpdateUser and CreateUser functions and corresponding unit tests

* Added sso_provider to audit data and UserSessionAssociations

* Addressed PR feedback

* Addressed more PR feedback
  • Loading branch information
ALCooper12 authored Nov 4, 2024
1 parent 41583d2 commit 0932b83
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 64 deletions.
65 changes: 48 additions & 17 deletions cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,29 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht
if createUserRequest.SAMLProviderID != "" {
if samlProviderID, err := serde.ParseInt32(createUserRequest.SAMLProviderID); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, fmt.Sprintf("SAML Provider ID must be a number: %v", err.Error()), request), response)
return
} else if samlProvider, err := s.db.GetSAMLProvider(request.Context(), samlProviderID); err != nil {
log.Errorf("Error while attempting to fetch SAML provider %d: %v", createUserRequest.SAMLProviderID, err)
api.HandleDatabaseError(request, response, err)
return
} else {
userTemplate.SAMLProviderID = null.Int32From(samlProvider.ID)
userTemplate.SSOProviderID = samlProvider.SSOProviderID
}
} else if createUserRequest.SSOProviderID.Valid {
if ssoProvider, err := s.db.GetSSOProviderById(request.Context(), createUserRequest.SSOProviderID.Int32); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
userTemplate.SSOProviderID = createUserRequest.SSOProviderID
if ssoProvider.Type == model.SessionAuthProviderSAML {
if ssoProvider.SAMLProvider != nil {
userTemplate.SAMLProviderID = null.Int32From(ssoProvider.SAMLProvider.ID)
}
} else {
userTemplate.SAMLProvider = nil
userTemplate.SAMLProviderID = null.NewInt32(0, false)
}
}
}

Expand All @@ -492,20 +510,10 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht
}
}

func (s ManagementResource) updateUser(response http.ResponseWriter, request *http.Request, user model.User) {
if err := s.db.UpdateUser(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
response.WriteHeader(http.StatusOK)
}
}

func (s ManagementResource) ensureUserHasNoAuthSecret(ctx context.Context, user model.User) error {
if user.AuthSecret != nil {
if err := s.db.DeleteAuthSecret(ctx, *user.AuthSecret); err != nil {
return api.FormatDatabaseError(err)
} else {
return nil
}
}

Expand Down Expand Up @@ -556,24 +564,47 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
// We're setting a SAML provider. If the user has an associated secret the secret will be removed.
if samlProviderID, err := serde.ParseInt32(updateUserRequest.SAMLProviderID); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, fmt.Sprintf("SAML Provider ID must be a number: %v", err.Error()), request), response)
return
} else if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if provider, err := s.db.GetSAMLProvider(request.Context(), samlProviderID); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
// Ensure that the AuthSecret reference is nil and that the SAML provider is set
user.AuthSecret = nil
user.SAMLProvider = &provider
user.SAMLProviderID = null.Int32From(samlProviderID)

s.updateUser(response, request, user)
user.SSOProviderID = provider.SSOProviderID
}
} else if updateUserRequest.SSOProviderID.Valid {
if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if ssoProvider, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
user.SSOProviderID = updateUserRequest.SSOProviderID
if ssoProvider.Type == model.SessionAuthProviderSAML {
if ssoProvider.SAMLProvider != nil {
user.SAMLProviderID = null.Int32From(ssoProvider.SAMLProvider.ID)
}
} else {
user.SAMLProvider = nil
user.SAMLProviderID = null.NewInt32(0, false)
}
}
} else {
// Default SAMLProviderID to null if the update request contains no SAMLProviderID
user.SAMLProviderID = null.NewInt32(0, false)
// Default SAMLProviderID and SSOProviderID to null if the update request contains no SAMLProviderID and SSOProviderID
user.SAMLProvider = nil
user.SAMLProviderID = null.NewInt32(0, false)
user.SSOProviderID = null.NewInt32(0, false)
}

s.updateUser(response, request, user)
if err := s.db.UpdateUser(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
response.WriteHeader(http.StatusOK)
}
}
}
Expand Down
90 changes: 53 additions & 37 deletions cmd/api/src/api/v2/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ import (

const (
samlProviderPathFmt = "/api/v2/saml/providers/%d"
updateUserPathFmt = "/api/v2/auth/users/%s"
updateUserSecretPathFmt = "/api/v2/auth/users/%s/secret"
ssoProviderID int32 = 123
samlProviderID int32 = 1234
samlProviderIDStr = "1234"
)
Expand Down Expand Up @@ -174,44 +174,60 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {

defer mockCtrl.Finish()

mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil).AnyTimes()
mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil).Times(2)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil).Times(2)
mockDB.EXPECT().DeleteAuthSecret(gomock.Any(), gomock.Any()).Return(nil)
t.Run("Successfully update user with deprecated saml provider", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SAMLProviderID: samlProviderIDStr,
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusOK)
})

// Happy path
test.Request(t).
WithMethod(http.MethodPut).
WithURL(fmt.Sprintf(updateUserPathFmt, goodUserID.String())). //nolint:govet // Ignore non-constant format string failure because it's test code
WithURLPathVars(map[string]string{
"user_id": goodUserID.String(),
}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SAMLProviderID: samlProviderIDStr,
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusOK)
t.Run("Fails if auth secret set", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)
mockDB.EXPECT().DeleteAuthSecret(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithURLPathVars(map[string]string{"user_id": badUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SAMLProviderID: samlProviderIDStr,
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusOK)
})

// Negative path where a user already has an auth secret set
test.Request(t).
WithMethod(http.MethodPut).
WithURL(fmt.Sprintf(updateUserPathFmt, badUserID.String())). //nolint:govet // Ignore non-constant format string failure because it's test code
WithURLPathVars(map[string]string{
"user_id": badUserID.String(),
}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SAMLProviderID: samlProviderIDStr,
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusOK)
t.Run("Successful user update with sso provider-saml", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProviderID).Return(model.SSOProvider{}, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SSOProviderID: null.Int32From(123),
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusOK)
})
}

func TestManagementResource_DeleteSAMLProvider(t *testing.T) {
Expand Down
16 changes: 9 additions & 7 deletions cmd/api/src/api/v2/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/specterops/bloodhound/src/auth"
"github.com/specterops/bloodhound/src/config"
"github.com/specterops/bloodhound/src/database"
"github.com/specterops/bloodhound/src/database/types/null"
"github.com/specterops/bloodhound/src/model"
"github.com/specterops/bloodhound/src/queries"
"github.com/specterops/bloodhound/src/serde"
Expand Down Expand Up @@ -59,13 +60,14 @@ type ListSAMLProvidersResponse struct {
}

type UpdateUserRequest struct {
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
EmailAddress string `json:"email_address"`
Principal string `json:"principal"`
Roles []int32 `json:"roles"`
SAMLProviderID string `json:"saml_provider_id"`
IsDisabled bool `json:"is_disabled"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
EmailAddress string `json:"email_address"`
Principal string `json:"principal"`
Roles []int32 `json:"roles"`
SAMLProviderID string `json:"saml_provider_id"`
SSOProviderID null.Int32 `json:"sso_provider_id"`
IsDisabled bool `json:"is_disabled"`
}

type CreateUserRequest struct {
Expand Down
9 changes: 9 additions & 0 deletions cmd/api/src/database/migration/migrations/v6.2.0.sql
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ ALTER TABLE ONLY users
DROP CONSTRAINT IF EXISTS fk_users_saml_provider;
ALTER TABLE ONLY users
ADD CONSTRAINT fk_users_saml_provider FOREIGN KEY (saml_provider_id) REFERENCES saml_providers (id) ON DELETE SET NULL;

-- Backfill users with their proper sso_provider when they have a saml_provider_id
UPDATE users u
SET sso_provider_id = (SELECT sso.id
FROM saml_providers saml
JOIN sso_providers sso ON sso.id = saml.sso_provider_id
WHERE u.saml_provider_id = saml.id)
WHERE sso_provider_id IS NULL
AND saml_provider_id IS NOT NULL;
15 changes: 15 additions & 0 deletions cmd/api/src/database/mocks/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions cmd/api/src/database/sso_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type SSOProviderData interface {
CreateSSOProvider(ctx context.Context, name string, authProvider model.SessionAuthProvider) (model.SSOProvider, error)
DeleteSSOProvider(ctx context.Context, id int) error
GetAllSSOProviders(ctx context.Context, order string, sqlFilter model.SQLFilter) ([]model.SSOProvider, error)
GetSSOProviderById(ctx context.Context, id int32) (model.SSOProvider, error)
GetSSOProviderBySlug(ctx context.Context, slug string) (model.SSOProvider, error)
GetSSOProviderUsers(ctx context.Context, id int) (model.Users, error)
}
Expand Down Expand Up @@ -137,3 +138,10 @@ func (s *BloodhoundDB) GetSSOProviderUsers(ctx context.Context, id int) (model.U

return users, CheckError(s.db.WithContext(ctx).Table("users").Where("sso_provider_id = ?", id).Find(&users))
}

func (s *BloodhoundDB) GetSSOProviderById(ctx context.Context, id int32) (model.SSOProvider, error) {
var provider model.SSOProvider
result := s.db.WithContext(ctx).Preload("OIDCProvider").Preload("SAMLProvider").Table(ssoProviderTableName).Where("id = ?", id).First(&provider)

return provider, CheckError(result)
}
22 changes: 22 additions & 0 deletions cmd/api/src/database/sso_providers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,25 @@ func TestBloodhoundDB_GetSSOProviderUsers(t *testing.T) {
assert.Equal(t, user.ID, returnedUsers[0].ID)
})
}

func TestBloodhoundDB_GetSSOProviderById(t *testing.T) {
var (
testCtx = context.Background()
dbInst = integration.SetupDB(t)
)
defer dbInst.Close(testCtx)

t.Run("successfully get sso provider by id", func(t *testing.T) {
newSamlProvider, err := dbInst.CreateSAMLIdentityProvider(testCtx, model.SAMLProvider{
Name: "someName",
DisplayName: "someName",
})
require.NoError(t, err)

provider, err := dbInst.GetSSOProviderById(testCtx, newSamlProvider.SSOProviderID.Int32)
require.NoError(t, err)

require.EqualValues(t, newSamlProvider.SSOProviderID.Int32, provider.ID)
require.NotNil(t, provider.SAMLProvider)
})
}
8 changes: 5 additions & 3 deletions cmd/api/src/model/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,9 @@ type User struct {
IsDisabled bool `json:"is_disabled"`
// EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users.
// This value is automatically set to true for Bloodhound Community Edition in the patchEULAAcceptance and CreateUser functions.
EULAAccepted bool `json:"eula_accepted"`

SSOProviderID null.Int32 `json:"sso_provider_id,omitempty"`
EULAAccepted bool `json:"eula_accepted"`
SSOProvider *SSOProvider `json:"-" `
SSOProviderID null.Int32 `json:"sso_provider_id,omitempty"`

Unique
}
Expand All @@ -486,6 +486,7 @@ func (s *User) AuditData() AuditData {
"email_address": s.EmailAddress.ValueOrZero(),
"roles": s.Roles.IDs(),
"saml_provider_id": s.SAMLProviderID.ValueOrZero(),
"sso_provider_id": s.SSOProviderID.ValueOrZero(),
"is_disabled": s.IsDisabled,
"eula_accepted": s.EULAAccepted,
}
Expand Down Expand Up @@ -562,6 +563,7 @@ func (s Users) GetValidFilterPredicatesAsStrings(column string) ([]string, error
func UserSessionAssociations() []string {
return []string{
"User.SAMLProvider",
"User.SSOProvider",
"User.AuthSecret",
"User.AuthTokens",
"User.Roles.Permissions",
Expand Down

0 comments on commit 0932b83

Please sign in to comment.