diff --git a/pkg/abac/pap/department.go b/pkg/abac/pap/department.go index 652de3e8..ff63f76a 100644 --- a/pkg/abac/pap/department.go +++ b/pkg/abac/pap/department.go @@ -61,20 +61,29 @@ func (c *departmentController) ListPaging(limit, offset int64) ([]SubjectDepartm pks = append(pks, svcSubjectDepartment.DepartmentPKs...) } + subjects, err := cacheimpls.BatchGetSubjectByPKs(pks) + if err != nil { + return nil, errorWrapf(err, "cacheimpls.BatchGetSubjectByPKs pks=`%v` fail", pks) + } + subjectMap := make(map[int64]types.Subject, len(pks)) - for _, pk := range pks { - subject, err := cacheimpls.GetSubjectByPK(pk) - if err != nil { - return nil, errorWrapf(err, "cacheimpls.GetSubjectByPK pk=`%d` fail", pk) - } - subjectMap[pk] = subject + for _, subject := range subjects { + subjectMap[subject.PK] = subject } subjectDepartments := make([]SubjectDepartment, 0, len(svcSubjectDepartments)) for _, svcSubjectDepartment := range svcSubjectDepartments { + if _, ok := subjectMap[svcSubjectDepartment.SubjectPK]; !ok { + continue + } + subjectID := subjectMap[svcSubjectDepartment.SubjectPK].ID departmentIDs := make([]string, 0, len(svcSubjectDepartment.DepartmentPKs)) for _, depPK := range svcSubjectDepartment.DepartmentPKs { + if _, ok := subjectMap[depPK]; !ok { + continue + } + departmentIDs = append(departmentIDs, subjectMap[depPK].ID) } diff --git a/pkg/abac/pap/department_test.go b/pkg/abac/pap/department_test.go index 73b8f86e..27322b82 100644 --- a/pkg/abac/pap/department_test.go +++ b/pkg/abac/pap/department_test.go @@ -59,27 +59,28 @@ var _ = Describe("DepartmentController", func() { }, nil, ).AnyTimes() - patches := gomonkey.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) { - switch pk { - case 1: - return types.Subject{ - ID: "1", - Type: "user", + patches := gomonkey.ApplyFunc( + cacheimpls.BatchGetSubjectByPKs, + func(pks []int64) (subjects []types.Subject, err error) { + return []types.Subject{ + { + PK: 1, + ID: "1", + Type: "user", + }, + { + PK: 2, + ID: "2", + Type: "department", + }, + { + PK: 3, + ID: "3", + Type: "department", + }, }, nil - case 2: - return types.Subject{ - ID: "2", - Type: "department", - }, nil - case 3: - return types.Subject{ - ID: "3", - Type: "department", - }, nil - } - - return types.Subject{}, nil - }) + }, + ) defer patches.Reset() manager := &departmentController{ diff --git a/pkg/abac/pap/group.go b/pkg/abac/pap/group.go index d915e2dd..1c3ffff0 100644 --- a/pkg/abac/pap/group.go +++ b/pkg/abac/pap/group.go @@ -151,17 +151,16 @@ func (c *groupController) FilterGroupsHasMemberBeforeExpiredAt(subjects []Subjec ) } - existGroups := make([]Subject, 0, len(existGroupPKs)) - for _, pk := range existGroupPKs { - subject, err := cacheimpls.GetSubjectByPK(pk) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - continue - } - - return nil, errorWrapf(err, "cacheimpls.GetSubjectByPK pk=`%d` fail", pk) - } + existSubjects, err := cacheimpls.BatchGetSubjectByPKs(existGroupPKs) + if err != nil { + return nil, errorWrapf( + err, "cacheimpls.BatchGetSubjectByPKs groupPKs=`%+v` fail", + existGroupPKs, + ) + } + existGroups := make([]Subject, 0, len(existGroupPKs)) + for _, subject := range existSubjects { existGroups = append(existGroups, Subject{ Type: subject.Type, ID: subject.ID, @@ -752,17 +751,13 @@ func (c *groupController) ListRbacGroupByActionResource( } func groupPKsToSubjects(groupPKs []int64) ([]Subject, error) { - groups := make([]Subject, 0, len(groupPKs)) - for _, pk := range groupPKs { - subject, err := cacheimpls.GetSubjectByPK(pk) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - continue - } - - return nil, fmt.Errorf("subject query fail, subjectPK=`%d`", pk) - } + subjects, err := cacheimpls.BatchGetSubjectByPKs(groupPKs) + if err != nil { + return nil, fmt.Errorf("cacheimpls.BatchGetSubjectByPKs fail, subjectPKs=`%v`", groupPKs) + } + groups := make([]Subject, 0, len(groupPKs)) + for _, subject := range subjects { groups = append(groups, Subject{ Type: subject.Type, ID: subject.ID, @@ -773,15 +768,26 @@ func groupPKsToSubjects(groupPKs []int64) ([]Subject, error) { } func convertToSubjectGroups(svcSubjectGroups []types.SubjectGroup) ([]SubjectGroup, error) { - groups := make([]SubjectGroup, 0, len(svcSubjectGroups)) + groupPKs := make([]int64, 0, len(svcSubjectGroups)) for _, m := range svcSubjectGroups { - subject, err := cacheimpls.GetSubjectByPK(m.GroupPK) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - continue - } + groupPKs = append(groupPKs, m.GroupPK) + } - return nil, err + subjects, err := cacheimpls.BatchGetSubjectByPKs(groupPKs) + if err != nil { + return nil, err + } + + subjectMap := make(map[int64]types.Subject, len(subjects)) + for _, subject := range subjects { + subjectMap[subject.PK] = subject + } + + groups := make([]SubjectGroup, 0, len(svcSubjectGroups)) + for _, m := range svcSubjectGroups { + subject, ok := subjectMap[m.GroupPK] + if !ok { + continue } groups = append(groups, SubjectGroup{ @@ -832,24 +838,32 @@ func convertToGroupMembers(svcGroupMembers []types.GroupMember) ([]GroupMember, } func convertToGroupSubjects(svcGroupSubjects []types.GroupSubject) ([]GroupSubject, error) { - groupSubjects := make([]GroupSubject, 0, len(svcGroupSubjects)) + subjectPKs := set.NewInt64Set() for _, m := range svcGroupSubjects { - subject, err := cacheimpls.GetSubjectByPK(m.SubjectPK) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - continue - } + subjectPKs.Add(m.SubjectPK) + subjectPKs.Add(m.GroupPK) + } - return nil, err - } + subjects, err := cacheimpls.BatchGetSubjectByPKs(subjectPKs.ToSlice()) + if err != nil { + return nil, err + } - group, err := cacheimpls.GetSubjectByPK(m.GroupPK) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - continue - } + subjectMap := make(map[int64]types.Subject, len(subjects)) + for _, subject := range subjects { + subjectMap[subject.PK] = subject + } - return nil, err + groupSubjects := make([]GroupSubject, 0, len(svcGroupSubjects)) + for _, m := range svcGroupSubjects { + subject, ok := subjectMap[m.SubjectPK] + if !ok { + continue + } + + group, ok := subjectMap[m.GroupPK] + if !ok { + continue } groupSubjects = append(groupSubjects, GroupSubject{ diff --git a/pkg/abac/pap/group_test.go b/pkg/abac/pap/group_test.go index 13adc5a3..1e7ca378 100644 --- a/pkg/abac/pap/group_test.go +++ b/pkg/abac/pap/group_test.go @@ -576,8 +576,8 @@ var _ = Describe("GroupController", func() { ). AnyTimes() - patches.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) { - return types.Subject{}, errors.New("err") + patches.ApplyFunc(cacheimpls.BatchGetSubjectByPKs, func(pks []int64) (subjects []types.Subject, err error) { + return nil, errors.New("err") }) c := &groupController{ @@ -615,8 +615,8 @@ var _ = Describe("GroupController", func() { ). AnyTimes() - patches.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) { - return types.Subject{}, nil + patches.ApplyFunc(cacheimpls.BatchGetSubjectByPKs, func(pks []int64) (subjects []types.Subject, err error) { + return []types.Subject{{}}, nil }) c := &groupController{ @@ -778,8 +778,8 @@ var _ = Describe("GroupController", func() { return []int64{1}, nil }) - patches.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) { - return types.Subject{}, errors.New("err") + patches.ApplyFunc(cacheimpls.BatchGetSubjectByPKs, func(pks []int64) (subjects []types.Subject, err error) { + return nil, errors.New("err") }) c := &groupController{} @@ -815,8 +815,8 @@ var _ = Describe("GroupController", func() { return []int64{1}, nil }) - patches.ApplyFunc(cacheimpls.GetSubjectByPK, func(pk int64) (subject types.Subject, err error) { - return types.Subject{}, nil + patches.ApplyFunc(cacheimpls.BatchGetSubjectByPKs, func(pks []int64) (subjects []types.Subject, err error) { + return []types.Subject{{}}, nil }) c := &groupController{}