Skip to content

Commit

Permalink
fix: update subject template group create delete
Browse files Browse the repository at this point in the history
  • Loading branch information
zhu327 committed Dec 11, 2023
1 parent a6ee3e1 commit db4b0c1
Show file tree
Hide file tree
Showing 9 changed files with 320 additions and 70 deletions.
15 changes: 15 additions & 0 deletions pkg/database/dao/mock/subject_template_group.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions pkg/database/dao/subject_template_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type SubjectTemplateGroupManager interface {
) (members []SubjectTemplateGroup, err error)
ListRelationBySubjectPKGroupPKs(subjectPK int64, groupPKs []int64) ([]SubjectTemplateGroup, error)
ListGroupDistinctSubjectPK(groupPK int64) (subjectPKs []int64, err error)
ListThinRelationWithMaxExpiredAtByGroupPK(groupPK int64) ([]ThinSubjectRelation, error)

BulkCreateWithTx(tx *sqlx.Tx, relations []SubjectTemplateGroup) error
BulkUpdateExpiredAtWithTx(tx *sqlx.Tx, relations []SubjectTemplateGroup) error
Expand Down Expand Up @@ -210,3 +211,27 @@ func (m *subjectTemplateGroupManager) ListGroupDistinctSubjectPK(groupPK int64)
}
return
}

func (m *subjectTemplateGroupManager) ListThinRelationWithMaxExpiredAtByGroupPK(
groupPK int64,
) ([]ThinSubjectRelation, error) {
relations := []ThinSubjectRelation{}

query := `SELECT
subject_pk,
MAX(expired_at) AS policy_expired_at
FROM subject_template_group
WHERE group_pk = ?
GROUP BY subject_pk`

err := database.SqlxSelect(m.DB, &relations, query, groupPK)
if errors.Is(err, sql.ErrNoRows) {
return relations, nil
}

for i := range relations {
relations[i].GroupPK = groupPK
}

return relations, err
}
22 changes: 22 additions & 0 deletions pkg/database/dao/subject_template_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,25 @@ func Test_subjectTemplateGroupManager_BulkUpdateExpiredAtWithTx(t *testing.T) {
assert.NoError(t, err, "query from db fail.")
})
}

func Test_subjectTemplateGroupManager_ListThinRelationWithMaxExpiredAtByGroupPK(t *testing.T) {
database.RunWithMock(t, func(db *sqlx.DB, mock sqlmock.Sqlmock, t *testing.T) {
groupPK := int64(1)
mockQuery := `^SELECT subject_pk, (.*) FROM subject_template_group WHERE group_pk`

rows := sqlmock.NewRows([]string{"subject_pk", "policy_expired_at"}).
AddRow(int64(1), int64(1)).
AddRow(int64(2), int64(2))

mock.ExpectQuery(mockQuery).WithArgs(groupPK).WillReturnRows(rows)

manager := &subjectTemplateGroupManager{DB: db}
relations, err := manager.ListThinRelationWithMaxExpiredAtByGroupPK(groupPK)

assert.NoError(t, err, "query from db failed")
assert.Len(t, relations, 2, "did not get expected number of relations")
for _, rel := range relations {
assert.Equal(t, groupPK, rel.GroupPK, "GroupPK in relation does not match")
}
})
}
117 changes: 92 additions & 25 deletions pkg/service/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ package service
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"

"github.com/TencentBlueKing/gopkg/collection/set"
Expand Down Expand Up @@ -441,7 +444,7 @@ func (l *groupService) UpdateGroupMembersExpiredAtWithTx(
}

for _, systemID := range systemIDs {
err = l.addOrUpdateSubjectSystemGroup(tx, m.SubjectPK, systemID, groupPK, m.ExpiredAt)
err = l.addOrUpdateSubjectSystemGroup(tx, m.SubjectPK, systemID, map[int64]int64{groupPK: m.ExpiredAt})
if err != nil {
return errorWrapf(
err,
Expand Down Expand Up @@ -528,7 +531,7 @@ func (l *groupService) BulkDeleteGroupMembers(
}

for _, systemID := range systemIDs {
err = l.removeSubjectSystemGroup(tx, subjectPK, systemID, groupPK)
err = l.removeSubjectSystemGroup(tx, subjectPK, systemID, map[int64]int64{groupPK: 0})
if err != nil {
return nil, errorWrapf(
err,
Expand Down Expand Up @@ -582,7 +585,7 @@ func (l *groupService) BulkCreateGroupMembersWithTx(
}

for _, systemID := range systemIDs {
err = l.addOrUpdateSubjectSystemGroup(tx, r.SubjectPK, systemID, groupPK, r.ExpiredAt)
err = l.addOrUpdateSubjectSystemGroup(tx, r.SubjectPK, systemID, map[int64]int64{groupPK: r.ExpiredAt})
if err != nil {
return errorWrapf(
err,
Expand Down Expand Up @@ -620,6 +623,41 @@ func (l *groupService) BulkCreateSubjectTemplateGroupWithTx(
return nil
}

type subjectSystemGroupHelper struct {
subjectSystemGroup map[string]map[int64]int64 // key: subjectPK:systemID, map: groupPK-expiredAt
}

// Add adds a group to the subjectSystemGroup map
func (h *subjectSystemGroupHelper) Add(subjectPK int64, systemID string, groupPK int64, expiredAt int64) {
key := h.generateKey(subjectPK, systemID)
if _, ok := h.subjectSystemGroup[key]; !ok {
h.subjectSystemGroup[key] = make(map[int64]int64)
}

h.subjectSystemGroup[key][groupPK] = expiredAt
}

// generateKey generates a key based on subjectPK and systemID
func (h *subjectSystemGroupHelper) generateKey(subjectPK int64, systemID string) string {
return fmt.Sprintf("%d:%s", subjectPK, systemID)
}

// ParseKey parses a key into subjectPK and systemID
func (h *subjectSystemGroupHelper) ParseKey(key string) (subjectPK int64, systemID string, err error) {
parts := strings.Split(key, ":")
if len(parts) != 2 {
return 0, "", fmt.Errorf("invalid key format")
}

subjectPK, err = strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return 0, "", err
}

systemID = parts[1]
return subjectPK, systemID, nil
}

func (l *groupService) BulkUpdateSubjectSystemGroupBySubjectTemplateGroupWithTx(
tx *sqlx.Tx,
relations []types.SubjectTemplateGroup,
Expand All @@ -629,6 +667,7 @@ func (l *groupService) BulkUpdateSubjectSystemGroupBySubjectTemplateGroupWithTx(
"BulkUpdateSubjectSystemGroupBySubjectTemplateGroupWithTx",
)

subjectSystemGroup := &subjectSystemGroupHelper{}
groupSystemIDCache := make(map[int64][]string)
for _, relation := range relations {
if !relation.NeedUpdate {
Expand All @@ -647,25 +686,38 @@ func (l *groupService) BulkUpdateSubjectSystemGroupBySubjectTemplateGroupWithTx(
}

for _, systemID := range systemIDs {
err := l.addOrUpdateSubjectSystemGroup(
tx,
subjectSystemGroup.Add(
relation.SubjectPK,
systemID,
relation.GroupPK,
relation.ExpiredAt,
)
if err != nil {
return errorWrapf(
err,
"addOrUpdateSubjectSystemGroup systemID=`%s`, subjectPK=`%d`, groupPK=`%d`, expiredAt=`%d`, fail",
systemID,
relation.SubjectPK,
relation.GroupPK,
relation.ExpiredAt,
)
}
}
}

for key, groups := range subjectSystemGroup.subjectSystemGroup {
subjectPK, systemID, err := subjectSystemGroup.ParseKey(key)
if err != nil {
return errorWrapf(err, "parseKey key=`%s` fail", key)
}

err = l.addOrUpdateSubjectSystemGroup(
tx,
subjectPK,
systemID,
groups,
)
if err != nil {
return errorWrapf(
err,
"addOrUpdateSubjectSystemGroup systemID=`%s`, subjectPK=`%d`, groups=`%v`, fail",
systemID,
subjectPK,
groups,
)
}
}

return nil
}

Expand Down Expand Up @@ -738,6 +790,7 @@ func (l *groupService) BulkDeleteSubjectTemplateGroupWithTx(
return errorWrapf(err, "subjectTemplateGroupManager.BulkDeleteWithTx relations=`%+v` fail", daoRelations)
}

subjectSystemGroup := &subjectSystemGroupHelper{}
groupSystemIDCache := make(map[int64][]string)
for _, relation := range relations {
if !relation.NeedUpdate {
Expand All @@ -756,16 +809,30 @@ func (l *groupService) BulkDeleteSubjectTemplateGroupWithTx(
}

for _, systemID := range systemIDs {
err = l.removeSubjectSystemGroup(tx, relation.SubjectPK, systemID, relation.GroupPK)
if err != nil {
return errorWrapf(
err,
"removeSubjectSystemGroup systemID=`%s`, subjectPK=`%d`, groupPK=`%d`, fail",
systemID,
relation.SubjectPK,
relation.GroupPK,
)
}
subjectSystemGroup.Add(relation.SubjectPK, systemID, relation.ExpiredAt, 0)
}
}

for key, groups := range subjectSystemGroup.subjectSystemGroup {
subjectPK, systemID, err := subjectSystemGroup.ParseKey(key)
if err != nil {
return errorWrapf(err, "parseKey key=`%s` fail", key)
}

err = l.removeSubjectSystemGroup(
tx,
subjectPK,
systemID,
groups,
)
if err != nil {
return errorWrapf(
err,
"removeSubjectSystemGroup systemID=`%s`, subjectPK=`%d`, groups=`%v`, fail",
systemID,
subjectPK,
groups,
)
}
}
return nil
Expand Down
60 changes: 52 additions & 8 deletions pkg/service/group_system_auth_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"errors"
"time"

"github.com/TencentBlueKing/gopkg/collection/set"
"github.com/TencentBlueKing/gopkg/errorx"
"github.com/jmoiron/sqlx"

Expand Down Expand Up @@ -53,12 +54,30 @@ func (s *groupService) AlterGroupAuthType(
return false, errorWrapf(err, "manager.ListGroupMember groupPK=`%d` fail", groupPK)
}

for _, member := range members {
err := s.removeSubjectSystemGroup(tx, member.SubjectPK, systemID, groupPK)
subjectSet := set.NewInt64Set()
for _, relation := range members {
subjectSet.Add(relation.SubjectPK)
}

// 查询用户组模版成员
relations, err := s.subjectTemplateGroupManager.ListThinRelationWithMaxExpiredAtByGroupPK(groupPK)
if err != nil {
return false, errorWrapf(
err,
"subjectTemplateGroupManager.ListThinRelationWithMaxExpiredAtByGroupPK groupPK=`%d` fail",
groupPK,
)
}
for _, relation := range relations {
subjectSet.Add(relation.SubjectPK)
}

for _, subjectPK := range subjectSet.ToSlice() {
err := s.removeSubjectSystemGroup(tx, subjectPK, systemID, map[int64]int64{groupPK: 0})
if err != nil {
return false, errorWrapf(
err, "removeSubjectSystemGroup member=`%d` systemID=`%s` groupPK=`%d` fail",
member.SubjectPK, systemID, groupPK,
subjectPK, systemID, groupPK,
)
}
}
Expand All @@ -81,20 +100,45 @@ func (s *groupService) AlterGroupAuthType(
return false, errorWrapf(err, "manager.ListGroupMember groupPK=`%d` fail", groupPK)
}

// 查询用户组模版成员
relations, err := s.subjectTemplateGroupManager.ListThinRelationWithMaxExpiredAtByGroupPK(groupPK)
if err != nil {
return false, errorWrapf(
err,
"subjectTemplateGroupManager.ListThinRelationWithMaxExpiredAtByGroupPK groupPK=`%d` fail",
groupPK,
)
}

nowTS := time.Now().Unix()
for _, member := range members {
// NOTE: subject system group表中只需要保持未过期的记录
if member.ExpiredAt < nowTS {
subjectExpiredAtMap := make(map[int64]int64, len(relations)+len(members))
for _, relation := range members {
if relation.ExpiredAt < nowTS {
continue
}

subjectExpiredAtMap[relation.SubjectPK] = relation.ExpiredAt
}

for _, relation := range relations {
if relation.ExpiredAt < nowTS {
continue
}

// 取过期时间大的
if relation.ExpiredAt > subjectExpiredAtMap[relation.SubjectPK] {
subjectExpiredAtMap[relation.SubjectPK] = relation.ExpiredAt
}
}

for subjectPK, expiredAt := range subjectExpiredAtMap {
err := s.addOrUpdateSubjectSystemGroup(
tx, member.SubjectPK, systemID, groupPK, member.ExpiredAt,
tx, subjectPK, systemID, map[int64]int64{groupPK: expiredAt},
)
if err != nil {
return false, errorWrapf(
err, "addOrUpdateSubjectSystemGroup member=`%d` systemID=`%s` groupPK=`%d` fail",
member.SubjectPK, systemID, groupPK,
subjectPK, systemID, groupPK,
)
}
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/service/group_system_auth_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,15 @@ var _ = Describe("GroupService", func() {
[]dao.SubjectRelation{}, nil,
).AnyTimes()

mockSubjectTemplateGroupManager := mock.NewMockSubjectTemplateGroupManager(ctl)
mockSubjectTemplateGroupManager.EXPECT().ListThinRelationWithMaxExpiredAtByGroupPK(int64(1)).Return(
[]dao.ThinSubjectRelation{}, nil,
)

manager := &groupService{
manager: mockSubjectRelationManger,
authTypeManger: mockGroupSystemAuthTypeManager,
manager: mockSubjectRelationManger,
authTypeManger: mockGroupSystemAuthTypeManager,
subjectTemplateGroupManager: mockSubjectTemplateGroupManager,
}

changed, err := manager.AlterGroupAuthType(nil, "test", 1, 0)
Expand Down
Loading

0 comments on commit db4b0c1

Please sign in to comment.