Skip to content

Commit

Permalink
feat: add subject template group api
Browse files Browse the repository at this point in the history
  • Loading branch information
zhu327 committed Oct 25, 2023
1 parent 82a7661 commit 952efb8
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 164 deletions.
91 changes: 33 additions & 58 deletions pkg/abac/pap/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,9 @@ func (c *groupController) BulkCreateSubjectTemplateGroup(subjectTemplateGroups [

subjectGroupHelper := newSubjectGroupHelper(c.service)

for i, relation := range relations {
for i := range relations {
relation := &relations[i]

authorized, subjectGroup, err := subjectGroupHelper.getSubjectGroup(relation.SubjectPK, relation.GroupPK)
if err != nil {
return errorWrapf(
Expand All @@ -501,12 +503,12 @@ func (c *groupController) BulkCreateSubjectTemplateGroup(subjectTemplateGroups [

// 2. 如果已授权并且过期时间大于当前时间, 不需要更新
if subjectGroup != nil && subjectGroup.ExpiredAt > relation.ExpiredAt {
relations[i].ExpiredAt = subjectGroup.ExpiredAt
relation.ExpiredAt = subjectGroup.ExpiredAt
continue
}

// 3. 其余场景需要更新
relations[i].NeedUpdate = true
relation.NeedUpdate = true
}

tx, err := database.GenerateDefaultDBTx()
Expand Down Expand Up @@ -581,14 +583,16 @@ func (c *groupController) BulkDeleteSubjectTemplateGroup(subjectTemplateGroups [
}

// 查询是否有其它的关系
for i, relation := range relations {
exist, err := c.service.HasRelationExceptTemplate(relation)
for i := range relations {
relation := &relations[i]

exist, err := c.service.HasRelationExceptTemplate(*relation)
if err != nil {
return errorWrapf(err, "service.HasRelationExceptTemplate relation=`%+v` fail", relation)
}

if !exist {
relations[i].NeedUpdate = true
relation.NeedUpdate = true
}
}

Expand Down Expand Up @@ -693,73 +697,58 @@ func (c *groupController) alterGroupMembers(
}

// 获取实际需要添加的member
createMembers := make([]types.SubjectRelationForCreate, 0, len(members))
createMembers := make([]types.SubjectTemplateGroup, 0, len(members))

// 需要更新过期时间的member
updateMembers := make([]types.SubjectRelationForUpdate, 0, len(members))

// 用于清理缓存
subjectPKs := make([]int64, 0, len(members))
updateMembers := make([]types.SubjectTemplateGroup, 0, len(members))

typeCount = map[string]int64{
types.UserType: 0,
types.DepartmentType: 0,
}

subjectGroupHelper := newSubjectGroupHelper(c.service)
subjectTemplateGroups, err := c.convertGroupMembersToSubjectTemplateGroups(groupPK, members)
if err != nil {
return nil, err
}

for i, m := range members {
subjectPK := subjectTemplateGroups[i].SubjectPK
authorized, subjectGroup, err := subjectGroupHelper.getSubjectGroup(subjectPK, groupPK)
subjectGroupHelper := newSubjectGroupHelper(c.service)
for i := range subjectTemplateGroups {
relation := &subjectTemplateGroups[i]

// 查询 subject group 已有的关系
authorized, subjectGroup, err := subjectGroupHelper.getSubjectGroup(relation.SubjectPK, groupPK)
if err != nil {
return nil, errorWrapf(
err,
"getSubjectGroup subjectPK=`%d`, groupPK=`%d` fail",
subjectPK,
relation.SubjectPK,
groupPK,
)
}

if authorized && subjectGroup != nil && subjectGroup.ExpiredAt > m.ExpiredAt {
m.ExpiredAt = subjectGroup.ExpiredAt
subjectTemplateGroups[i].ExpiredAt = subjectGroup.ExpiredAt
if authorized && subjectGroup != nil && subjectGroup.ExpiredAt > relation.ExpiredAt {
relation.ExpiredAt = subjectGroup.ExpiredAt
}

// member已存在则不再添加
if oldMember, ok := memberMap[subjectPK]; ok {
if oldMember, ok := memberMap[relation.SubjectPK]; ok {
// 如果过期时间大于已有的时间, 则更新过期时间
if m.ExpiredAt > oldMember.ExpiredAt {
updateMembers = append(updateMembers, types.SubjectRelationForUpdate{
PK: oldMember.PK,
SubjectPK: subjectPK,
ExpiredAt: m.ExpiredAt,
})

subjectPKs = append(subjectPKs, subjectPK)
}
if relation.ExpiredAt > oldMember.ExpiredAt {
relation.NeedUpdate = true

if authorized && (subjectGroup == nil || subjectGroup.ExpiredAt < m.ExpiredAt) {
subjectTemplateGroups[i].NeedUpdate = true
updateMembers = append(updateMembers, *relation)
}
continue
}

if createIfNotExists {
createMembers = append(createMembers, types.SubjectRelationForCreate{
SubjectPK: subjectPK,
GroupPK: groupPK,
ExpiredAt: m.ExpiredAt,
})
typeCount[m.Type]++
subjectPKs = append(subjectPKs, subjectPK)

if authorized && (subjectGroup == nil || subjectGroup.ExpiredAt < m.ExpiredAt) {
subjectTemplateGroups[i].NeedUpdate = true
if authorized && (subjectGroup == nil || subjectGroup.ExpiredAt < relation.ExpiredAt) {
relation.NeedUpdate = true
}

createMembers = append(createMembers, *relation)
typeCount[members[i].Type]++
}
}

Expand Down Expand Up @@ -803,18 +792,8 @@ func (c *groupController) alterGroupMembers(
return nil, errorWrapf(err, "tx commit error")
}

needUpdateSubjectPKs := make([]int64, 0, len(subjectPKs))
for _, group := range subjectTemplateGroups {
if group.NeedUpdate {
needUpdateSubjectPKs = append(needUpdateSubjectPKs, group.SubjectPK)
}
}

// 创建group_alter_event
c.createGroupAlterEvent(groupPK, needUpdateSubjectPKs)

// 清理subject system group 缓存
cacheimpls.BatchDeleteSubjectAuthSystemGroupCache(needUpdateSubjectPKs, groupPK)
c.deleteSubjectTemplateGroupCache(subjectTemplateGroups)

return typeCount, nil
}
Expand All @@ -824,14 +803,10 @@ func (c *groupController) updateSubjectGroupExpiredAtWithTx(
subjectTemplateGroups []types.SubjectTemplateGroup,
updateGroupRelation bool,
) error {
needUpdateRelations := make([]types.SubjectRelationForCreate, 0, len(subjectTemplateGroups))
needUpdateRelations := make([]types.SubjectTemplateGroup, 0, len(subjectTemplateGroups))
for _, relation := range subjectTemplateGroups {
if relation.NeedUpdate {
needUpdateRelations = append(needUpdateRelations, types.SubjectRelationForCreate{
SubjectPK: relation.SubjectPK,
GroupPK: relation.GroupPK,
ExpiredAt: relation.ExpiredAt,
})
needUpdateRelations = append(needUpdateRelations, relation)
}
}

Expand Down
18 changes: 6 additions & 12 deletions pkg/abac/pap/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ var _ = Describe("GroupController", func() {
mockGroupService.EXPECT().
UpdateGroupMembersExpiredAtWithTx(
gomock.Any(), int64(1),
[]types.SubjectRelationForUpdate{{PK: 1, SubjectPK: 2, ExpiredAt: 3}},
[]types.SubjectTemplateGroup{{SubjectPK: 2, GroupPK: 1, ExpiredAt: 3, NeedUpdate: true}},
).
Return(
errors.New("error"),
Expand Down Expand Up @@ -126,13 +126,13 @@ var _ = Describe("GroupController", func() {
[]types.GroupMember{}, nil,
).AnyTimes()
mockGroupService.EXPECT().
UpdateGroupMembersExpiredAtWithTx(gomock.Any(), int64(1), []types.SubjectRelationForUpdate{{PK: 1, SubjectPK: 2, ExpiredAt: 3}}).
UpdateGroupMembersExpiredAtWithTx(gomock.Any(), int64(1), []types.SubjectTemplateGroup{{SubjectPK: 2, GroupPK: 1, ExpiredAt: 3}}).
Return(
nil,
).
AnyTimes()
mockGroupService.EXPECT().
BulkCreateGroupMembersWithTx(gomock.Any(), int64(1), []types.SubjectRelationForCreate{{
BulkCreateGroupMembersWithTx(gomock.Any(), int64(1), []types.SubjectTemplateGroup{{
SubjectPK: 2,
GroupPK: 1,
ExpiredAt: int64(3),
Expand Down Expand Up @@ -175,16 +175,13 @@ var _ = Describe("GroupController", func() {
mockGroupService.EXPECT().
UpdateGroupMembersExpiredAtWithTx(
gomock.Any(), int64(1),
[]types.SubjectRelationForUpdate{{PK: 1, SubjectPK: 2, ExpiredAt: 3}},
[]types.SubjectTemplateGroup{{SubjectPK: 2, GroupPK: 1, ExpiredAt: 3}},
).Return(
nil,
).
AnyTimes()
mockGroupService.EXPECT().ListGroupAuthSystemIDs(int64(1)).Return([]string{}, nil).AnyTimes()
mockGroupAlterEventService := mock.NewMockGroupAlterEventService(ctl)
mockGroupAlterEventService.EXPECT().
CreateByGroupSubject(gomock.Any(), gomock.Any()).
Return(errors.New("error"))
mockGroupService.EXPECT().GetGroupOneAuthSystem(int64(1)).Return("", nil).AnyTimes()

patches.ApplyFunc(service.NewGroupService, func() service.GroupService {
Expand Down Expand Up @@ -223,14 +220,14 @@ var _ = Describe("GroupController", func() {
mockGroupService.EXPECT().
UpdateGroupMembersExpiredAtWithTx(
gomock.Any(), int64(1),
[]types.SubjectRelationForUpdate{{PK: 1, SubjectPK: 2, ExpiredAt: 3}},
[]types.SubjectTemplateGroup{{SubjectPK: 2, GroupPK: 1, ExpiredAt: 3}},
).
Return(
nil,
).
AnyTimes()
mockGroupService.EXPECT().
BulkCreateGroupMembersWithTx(gomock.Any(), int64(1), []types.SubjectRelationForCreate{{
BulkCreateGroupMembersWithTx(gomock.Any(), int64(1), []types.SubjectTemplateGroup{{
SubjectPK: 2,
GroupPK: 1,
ExpiredAt: int64(3),
Expand All @@ -241,9 +238,6 @@ var _ = Describe("GroupController", func() {
AnyTimes()
mockGroupService.EXPECT().ListGroupAuthSystemIDs(int64(1)).Return([]string{}, nil).AnyTimes()
mockGroupAlterEventService := mock.NewMockGroupAlterEventService(ctl)
mockGroupAlterEventService.EXPECT().
CreateByGroupSubject(gomock.Any(), gomock.Any()).
Return(errors.New("error"))
mockGroupService.EXPECT().GetGroupOneAuthSystem(int64(1)).Return("", nil).AnyTimes()

patches.ApplyFunc(service.NewGroupService, func() service.GroupService {
Expand Down
14 changes: 0 additions & 14 deletions pkg/database/dao/mock/subject_group.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 0 additions & 17 deletions pkg/database/dao/subject_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ type SubjectRelation struct {
CreatedAt time.Time `db:"created_at"`
}

// SubjectRelationForUpdateExpiredAt keep the PrimaryKey and policy_expired_at
type SubjectRelationForUpdateExpiredAt struct {
PK int64 `db:"pk"`
// NOTE: map policy_expired_at to ExpiredAt in dao
ExpiredAt int64 `db:"policy_expired_at"`
}

// ThinSubjectRelation with the minimum fields of the relationship: subject-group-policy_expired_at
type ThinSubjectRelation struct {
SubjectPK int64 `db:"subject_pk"`
Expand Down Expand Up @@ -75,7 +68,6 @@ type SubjectGroupManager interface {

FilterGroupPKsHasMemberBeforeExpiredAt(groupPKs []int64, expiredAt int64) ([]int64, error)

UpdateExpiredAtWithTx(tx *sqlx.Tx, relations []SubjectRelationForUpdateExpiredAt) error
BulkCreateWithTx(tx *sqlx.Tx, relations []SubjectRelation) error
BulkDeleteBySubjectPKs(tx *sqlx.Tx, subjectPKs []int64) error
BulkDeleteByGroupPKs(tx *sqlx.Tx, groupPKs []int64) error
Expand Down Expand Up @@ -395,15 +387,6 @@ func (m *subjectGroupManager) BulkDeleteByGroupPKs(tx *sqlx.Tx, groupPKs []int64
return m.bulkDeleteByGroupPKs(tx, groupPKs)
}

// UpdateExpiredAtWithTx ...
func (m *subjectGroupManager) UpdateExpiredAtWithTx(
tx *sqlx.Tx,
relations []SubjectRelationForUpdateExpiredAt,
) error {
sql := `UPDATE subject_relation SET policy_expired_at = :policy_expired_at WHERE pk = :pk`
return database.SqlxBulkUpdateWithTx(tx, sql, relations)
}

// GetGroupMemberCountBeforeExpiredAt ...
func (m *subjectGroupManager) GetGroupMemberCountBeforeExpiredAt(
groupPK int64, expiredAt int64,
Expand Down
26 changes: 0 additions & 26 deletions pkg/database/dao/subject_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,32 +160,6 @@ func Test_subjectRelationManager_GetMemberCountBeforeExpiredAt(t *testing.T) {
})
}

func Test_subjectRelationManager_UpdateExpiredAtWithTx(t *testing.T) {
database.RunWithMock(t, func(db *sqlx.DB, mock sqlmock.Sqlmock, t *testing.T) {
mock.ExpectBegin()
mock.ExpectPrepare(`^UPDATE subject_relation SET policy_expired_at = (.*) WHERE pk = (.*)`)
mock.ExpectExec(`^UPDATE subject_relation SET policy_expired_at =`).WithArgs(
int64(2), int64(1),
).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()

subjects := []SubjectRelationForUpdateExpiredAt{{
PK: 1,
ExpiredAt: 2,
}}

tx, err := db.Beginx()
assert.NoError(t, err)

manager := &subjectGroupManager{DB: db}
err = manager.UpdateExpiredAtWithTx(tx, subjects)

tx.Commit()

assert.NoError(t, err)
})
}

func Test_subjectRelationManager_BulkCreateWithTx(t *testing.T) {
database.RunWithMock(t, func(db *sqlx.DB, mock sqlmock.Sqlmock, t *testing.T) {
mock.ExpectBegin()
Expand Down
Loading

0 comments on commit 952efb8

Please sign in to comment.