diff --git a/changelog/unreleased/user-rest-refactor.md b/changelog/unreleased/user-rest-refactor.md new file mode 100644 index 0000000000..2543bbd504 --- /dev/null +++ b/changelog/unreleased/user-rest-refactor.md @@ -0,0 +1,7 @@ +Enhancement: Refactor the rest user and group provider drivers + +We now maintain our own cache for all user and group data, and periodically +refresh it. A redis server now becomes a necessary dependency, whereas it was +optional previously. + +https://github.com/cs3org/reva/pull/2752 \ No newline at end of file diff --git a/internal/http/services/ocmd/ocmd.go b/internal/http/services/ocmd/ocmd.go index 776cf36414..c65cec5d67 100644 --- a/internal/http/services/ocmd/ocmd.go +++ b/internal/http/services/ocmd/ocmd.go @@ -47,9 +47,9 @@ type Config struct { func (c *Config) init() { c.GatewaySvc = sharedconf.GetGatewaySVC(c.GatewaySvc) - // if c.Prefix == "" { - // c.Prefix = "ocm" - // } + if c.Prefix == "" { + c.Prefix = "ocm" + } } type svc struct { diff --git a/pkg/cbox/group/rest/cache.go b/pkg/cbox/group/rest/cache.go index 80da1d4218..cd4110fd65 100644 --- a/pkg/cbox/group/rest/cache.go +++ b/pkg/cbox/group/rest/cache.go @@ -21,7 +21,9 @@ package rest import ( "encoding/json" "errors" + "fmt" "strconv" + "strings" "time" grouppb "github.com/cs3org/go-cs3apis/cs3/identity/group/v1beta1" @@ -31,6 +33,9 @@ import ( const ( groupPrefix = "group:" + idPrefix = "id:" + namePrefix = "name:" + gidPrefix = "gid:" groupMembersPrefix = "members:" groupInternalIDPrefix = "internal:" ) @@ -76,14 +81,12 @@ func (m *manager) setVal(key, val string, expiration int) error { conn := m.redisPool.Get() defer conn.Close() if conn != nil { + args := []interface{}{key, val} if expiration != -1 { - if _, err := conn.Do("SET", key, val, "EX", expiration); err != nil { - return err - } - } else { - if _, err := conn.Do("SET", key, val); err != nil { - return err - } + args = append(args, "EX", expiration) + } + if _, err := conn.Do("SET", args...); err != nil { + return err } return nil } @@ -111,8 +114,46 @@ func (m *manager) cacheInternalID(gid *grouppb.GroupId, internalID string) error return m.setVal(groupPrefix+groupInternalIDPrefix+gid.OpaqueId, internalID, -1) } +func (m *manager) findCachedGroups(query string) ([]*grouppb.Group, error) { + conn := m.redisPool.Get() + defer conn.Close() + if conn != nil { + query = fmt.Sprintf("%s*%s*", groupPrefix, strings.ReplaceAll(strings.ToLower(query), " ", "_")) + keys, err := redis.Strings(conn.Do("KEYS", query)) + if err != nil { + return nil, err + } + var args []interface{} + for _, k := range keys { + args = append(args, k) + } + + // Fetch the groups for all these keys + groupStrings, err := redis.Strings(conn.Do("MGET", args...)) + if err != nil { + return nil, err + } + groupMap := make(map[string]*grouppb.Group) + for _, group := range groupStrings { + g := grouppb.Group{} + if err = json.Unmarshal([]byte(group), &g); err == nil { + groupMap[g.Id.OpaqueId] = &g + } + } + + var groups []*grouppb.Group + for _, g := range groupMap { + groups = append(groups, g) + } + + return groups, nil + } + + return nil, errors.New("rest: unable to get connection from redis pool") +} + func (m *manager) fetchCachedGroupDetails(gid *grouppb.GroupId) (*grouppb.Group, error) { - group, err := m.getVal(groupPrefix + gid.OpaqueId) + group, err := m.getVal(groupPrefix + idPrefix + gid.OpaqueId) if err != nil { return nil, err } @@ -129,28 +170,38 @@ func (m *manager) cacheGroupDetails(g *grouppb.Group) error { if err != nil { return err } - if err = m.setVal(groupPrefix+g.Id.OpaqueId, string(encodedGroup), -1); err != nil { + if err = m.setVal(groupPrefix+idPrefix+strings.ToLower(g.Id.OpaqueId), string(encodedGroup), -1); err != nil { return err } - if err = m.setVal(groupPrefix+"gid_number:"+strconv.FormatInt(g.GidNumber, 10), g.Id.OpaqueId, -1); err != nil { - return err - } - if err = m.setVal(groupPrefix+"mail:"+g.Mail, g.Id.OpaqueId, -1); err != nil { - return err + if g.GidNumber != 0 { + if err = m.setVal(groupPrefix+gidPrefix+strconv.FormatInt(g.GidNumber, 10), g.Id.OpaqueId, -1); err != nil { + return err + } } - if err = m.setVal(groupPrefix+"group_name:"+g.GroupName, g.Id.OpaqueId, -1); err != nil { - return err + if g.DisplayName != "" { + if err = m.setVal(groupPrefix+namePrefix+g.Id.OpaqueId+"_"+strings.ToLower(g.DisplayName), g.Id.OpaqueId, -1); err != nil { + return err + } } return nil } -func (m *manager) fetchCachedParam(field, claim string) (string, error) { - return m.getVal(groupPrefix + field + ":" + claim) +func (m *manager) fetchCachedGroupByParam(field, claim string) (*grouppb.Group, error) { + group, err := m.getVal(groupPrefix + field + ":" + strings.ToLower(claim)) + if err != nil { + return nil, err + } + + g := grouppb.Group{} + if err = json.Unmarshal([]byte(group), &g); err != nil { + return nil, err + } + return &g, nil } func (m *manager) fetchCachedGroupMembers(gid *grouppb.GroupId) ([]*userpb.UserId, error) { - members, err := m.getVal(groupPrefix + groupMembersPrefix + gid.OpaqueId) + members, err := m.getVal(groupPrefix + groupMembersPrefix + strings.ToLower(gid.OpaqueId)) if err != nil { return nil, err } @@ -166,8 +217,5 @@ func (m *manager) cacheGroupMembers(gid *grouppb.GroupId, members []*userpb.User if err != nil { return err } - if err = m.setVal(groupPrefix+groupMembersPrefix+gid.OpaqueId, string(u), m.conf.GroupMembersCacheExpiration*60); err != nil { - return err - } - return nil + return m.setVal(groupPrefix+groupMembersPrefix+strings.ToLower(gid.OpaqueId), string(u), m.conf.GroupMembersCacheExpiration*60) } diff --git a/pkg/cbox/group/rest/rest.go b/pkg/cbox/group/rest/rest.go index 3852b68027..9c14930fcb 100644 --- a/pkg/cbox/group/rest/rest.go +++ b/pkg/cbox/group/rest/rest.go @@ -22,9 +22,11 @@ import ( "context" "errors" "fmt" - "net/url" - "regexp" + "os" + "os/signal" "strings" + "syscall" + "time" grouppb "github.com/cs3org/go-cs3apis/cs3/identity/group/v1beta1" userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1" @@ -34,16 +36,13 @@ import ( "github.com/cs3org/reva/pkg/group/manager/registry" "github.com/gomodule/redigo/redis" "github.com/mitchellh/mapstructure" + "github.com/rs/zerolog/log" ) func init() { registry.Register("rest", New) } -var ( - emailRegex = regexp.MustCompile(`^[\w-\.]+@([\w-]+\.)+[\w-]{2,4}$`) -) - type manager struct { conf *config redisPool *redis.Pool @@ -62,7 +61,7 @@ type config struct { // The OIDC Provider IDProvider string `mapstructure:"id_provider" docs:"http://cernbox.cern.ch"` // Base API Endpoint - APIBaseURL string `mapstructure:"api_base_url" docs:"https://authorization-service-api-dev.web.cern.ch/api/v1.0"` + APIBaseURL string `mapstructure:"api_base_url" docs:"https://authorization-service-api-dev.web.cern.ch"` // Client ID needed to authenticate ClientID string `mapstructure:"client_id" docs:"-"` // Client Secret @@ -72,6 +71,8 @@ type config struct { OIDCTokenEndpoint string `mapstructure:"oidc_token_endpoint" docs:"https://keycloak-dev.cern.ch/auth/realms/cern/api-access/token"` // The target application for which token needs to be generated TargetAPI string `mapstructure:"target_api" docs:"authorization-service-api"` + // The time in seconds between bulk fetch of groups + GroupFetchInterval int `mapstructure:"group_fetch_interval" docs:"3600"` } func (c *config) init() { @@ -82,7 +83,7 @@ func (c *config) init() { c.RedisAddress = ":6379" } if c.APIBaseURL == "" { - c.APIBaseURL = "https://authorization-service-api-dev.web.cern.ch/api/v1.0" + c.APIBaseURL = "https://authorization-service-api-dev.web.cern.ch" } if c.TargetAPI == "" { c.TargetAPI = "authorization-service-api" @@ -93,6 +94,9 @@ func (c *config) init() { if c.IDProvider == "" { c.IDProvider = "http://cernbox.cern.ch" } + if c.GroupFetchInterval == 0 { + c.GroupFetchInterval = 3600 + } } func parseConfig(m map[string]interface{}) (*config, error) { @@ -113,57 +117,78 @@ func New(m map[string]interface{}) (group.Manager, error) { redisPool := initRedisPool(c.RedisAddress, c.RedisUsername, c.RedisPassword) apiTokenManager := utils.InitAPITokenManager(c.TargetAPI, c.OIDCTokenEndpoint, c.ClientID, c.ClientSecret) - return &manager{ + + mgr := &manager{ conf: c, redisPool: redisPool, apiTokenManager: apiTokenManager, - }, nil -} - -func (m *manager) getGroupByParam(ctx context.Context, param, val string) (map[string]interface{}, error) { - url := fmt.Sprintf("%s/Group?filter=%s:%s&field=groupIdentifier&field=displayName&field=gid", - m.conf.APIBaseURL, param, val) - responseData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) - if err != nil { - return nil, err - } - if len(responseData) != 1 { - return nil, errors.New("rest: group not found: " + param + ": " + val) } + go mgr.fetchAllGroups() + return mgr, nil +} - userData, ok := responseData[0].(map[string]interface{}) - if !ok { - return nil, errors.New("rest: error in type assertion") +func (m *manager) fetchAllGroups() { + _ = m.fetchAllGroupAccounts() + ticker := time.NewTicker(time.Duration(m.conf.GroupFetchInterval) * time.Second) + work := make(chan os.Signal, 1) + signal.Notify(work, syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT) + + for { + select { + case <-work: + return + case <-ticker.C: + _ = m.fetchAllGroupAccounts() + } } - return userData, nil } -func (m *manager) getInternalGroupID(ctx context.Context, gid *grouppb.GroupId) (string, error) { +func (m *manager) fetchAllGroupAccounts() error { + ctx := context.Background() + url := fmt.Sprintf("%s/api/v1.0/Group?field=groupIdentifier&field=displayName&field=gid", m.conf.APIBaseURL) - internalID, err := m.fetchCachedInternalID(gid) - if err != nil { - groupData, err := m.getGroupByParam(ctx, "groupIdentifier", gid.OpaqueId) + for url != "" { + result, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) if err != nil { - return "", err + return err } - id, ok := groupData["id"].(string) + + responseData, ok := result["data"].([]interface{}) if !ok { - return "", errors.New("rest: error in type assertion") + return errors.New("rest: error in type assertion") } + for _, usr := range responseData { + groupData, ok := usr.(map[string]interface{}) + if !ok { + continue + } - if err = m.cacheInternalID(gid, id); err != nil { - log := appctx.GetLogger(ctx) - log.Error().Err(err).Msg("rest: error caching group details") + _, err = m.parseAndCacheGroup(ctx, groupData) + if err != nil { + continue + } + } + + url = "" + if pagination, ok := result["pagination"].(map[string]interface{}); ok { + if links, ok := pagination["links"].(map[string]interface{}); ok { + if next, ok := links["next"].(string); ok { + url = fmt.Sprintf("%s%s", m.conf.APIBaseURL, next) + } + } } - return id, nil } - return internalID, nil + + return nil } -func (m *manager) parseAndCacheGroup(ctx context.Context, groupData map[string]interface{}) *grouppb.Group { - id, _ := groupData["groupIdentifier"].(string) - name, _ := groupData["displayName"].(string) +func (m *manager) parseAndCacheGroup(ctx context.Context, groupData map[string]interface{}) (*grouppb.Group, error) { + id, ok := groupData["groupIdentifier"].(string) + if !ok { + return nil, errors.New("rest: missing upn in user data") + } + name, _ := groupData["displayName"].(string) groupID := &grouppb.GroupId{ OpaqueId: id, Idp: m.conf.IDProvider, @@ -181,25 +206,23 @@ func (m *manager) parseAndCacheGroup(ctx context.Context, groupData map[string]i } if err := m.cacheGroupDetails(g); err != nil { - log := appctx.GetLogger(ctx) log.Error().Err(err).Msg("rest: error caching group details") } - if err := m.cacheInternalID(groupID, groupData["id"].(string)); err != nil { - log := appctx.GetLogger(ctx) - log.Error().Err(err).Msg("rest: error caching group details") + + if internalID, ok := groupData["id"].(string); ok { + if err := m.cacheInternalID(groupID, internalID); err != nil { + log.Error().Err(err).Msg("rest: error caching group details") + } } - return g + + return g, nil } func (m *manager) GetGroup(ctx context.Context, gid *grouppb.GroupId, skipFetchingMembers bool) (*grouppb.Group, error) { g, err := m.fetchCachedGroupDetails(gid) if err != nil { - groupData, err := m.getGroupByParam(ctx, "groupIdentifier", gid.OpaqueId) - if err != nil { - return nil, err - } - g = m.parseAndCacheGroup(ctx, groupData) + return nil, err } if !skipFetchingMembers { @@ -214,29 +237,14 @@ func (m *manager) GetGroup(ctx context.Context, gid *grouppb.GroupId, skipFetchi } func (m *manager) GetGroupByClaim(ctx context.Context, claim, value string, skipFetchingMembers bool) (*grouppb.Group, error) { - value = url.QueryEscape(value) - opaqueID, err := m.fetchCachedParam(claim, value) - if err == nil { - return m.GetGroup(ctx, &grouppb.GroupId{OpaqueId: opaqueID}, skipFetchingMembers) - } - - switch claim { - case "mail": - claim = "groupIdentifier" - value = strings.TrimSuffix(value, "@cern.ch") - case "gid_number": - claim = "gid" - case "group_name": - claim = "groupIdentifier" - default: - return nil, errors.New("rest: invalid field: " + claim) + if claim == "group_name" { + return m.GetGroup(ctx, &grouppb.GroupId{OpaqueId: value}, skipFetchingMembers) } - groupData, err := m.getGroupByParam(ctx, claim, value) + g, err := m.fetchCachedGroupByParam(claim, value) if err != nil { return nil, err } - g := m.parseAndCacheGroup(ctx, groupData) if !skipFetchingMembers { groupMembers, err := m.GetMembers(ctx, g.Id) @@ -247,52 +255,6 @@ func (m *manager) GetGroupByClaim(ctx context.Context, claim, value string, skip } return g, nil - -} - -func (m *manager) findGroupsByFilter(ctx context.Context, url string, groups map[string]*grouppb.Group, skipFetchingMembers bool) error { - - groupData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) - if err != nil { - return err - } - - for _, grp := range groupData { - grpInfo, ok := grp.(map[string]interface{}) - if !ok { - continue - } - id, _ := grpInfo["groupIdentifier"].(string) - name, _ := grpInfo["displayName"].(string) - - groupID := &grouppb.GroupId{ - OpaqueId: id, - Idp: m.conf.IDProvider, - } - - var groupMembers []*userpb.UserId - if !skipFetchingMembers { - groupMembers, err = m.GetMembers(ctx, groupID) - if err != nil { - return err - } - } - gid, ok := grpInfo["gid"].(int64) - if !ok { - gid = 0 - } - - groups[groupID.OpaqueId] = &grouppb.Group{ - Id: groupID, - GroupName: id, - Mail: id + "@cern.ch", - DisplayName: name, - GidNumber: gid, - Members: groupMembers, - } - } - - return nil } func (m *manager) FindGroups(ctx context.Context, query string, skipFetchingMembers bool) ([]*grouppb.Group, error) { @@ -311,29 +273,7 @@ func (m *manager) FindGroups(ctx context.Context, query string, skipFetchingMemb } } - filters := []string{"groupIdentifier"} - if emailRegex.MatchString(query) { - parts := strings.Split(query, "@") - query = parts[0] - } - - groups := make(map[string]*grouppb.Group) - - for _, f := range filters { - url := fmt.Sprintf("%s/Group/?filter=%s:contains:%s&field=groupIdentifier&field=displayName&field=gid", - m.conf.APIBaseURL, f, url.QueryEscape(query)) - err := m.findGroupsByFilter(ctx, url, groups, skipFetchingMembers) - if err != nil { - return nil, err - } - } - - groupSlice := []*grouppb.Group{} - for _, v := range groups { - groupSlice = append(groupSlice, v) - } - - return groupSlice, nil + return m.findCachedGroups(query) } func (m *manager) GetMembers(ctx context.Context, gid *grouppb.GroupId) ([]*userpb.UserId, error) { @@ -343,16 +283,17 @@ func (m *manager) GetMembers(ctx context.Context, gid *grouppb.GroupId) ([]*user return users, nil } - internalID, err := m.getInternalGroupID(ctx, gid) + internalID, err := m.fetchCachedInternalID(gid) if err != nil { return nil, err } - url := fmt.Sprintf("%s/Group/%s/memberidentities/precomputed", m.conf.APIBaseURL, internalID) - userData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) + url := fmt.Sprintf("%s/api/v1.0/Group/%s/memberidentities/precomputed", m.conf.APIBaseURL, internalID) + result, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) if err != nil { return nil, err } + userData := result["data"].([]interface{}) users = []*userpb.UserId{} for _, u := range userData { diff --git a/pkg/cbox/user/rest/cache.go b/pkg/cbox/user/rest/cache.go index 5361f0c164..7f855045dd 100644 --- a/pkg/cbox/user/rest/cache.go +++ b/pkg/cbox/user/rest/cache.go @@ -21,6 +21,9 @@ package rest import ( "encoding/json" "errors" + "fmt" + "strconv" + "strings" "time" userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1" @@ -28,9 +31,12 @@ import ( ) const ( - userPrefix = "user:" - userGroupsPrefix = "groups:" - userInternalIDPrefix = "internal:" + userPrefix = "user:" + usernamePrefix = "username:" + namePrefix = "name:" + mailPrefix = "mail:" + uidPrefix = "uid:" + userGroupsPrefix = "groups:" ) func initRedisPool(address, username, password string) *redis.Pool { @@ -67,14 +73,12 @@ func (m *manager) setVal(key, val string, expiration int) error { conn := m.redisPool.Get() defer conn.Close() if conn != nil { + args := []interface{}{key, val} if expiration != -1 { - if _, err := conn.Do("SET", key, val, "EX", expiration); err != nil { - return err - } - } else { - if _, err := conn.Do("SET", key, val); err != nil { - return err - } + args = append(args, "EX", expiration) + } + if _, err := conn.Do("SET", args...); err != nil { + return err } return nil } @@ -94,16 +98,46 @@ func (m *manager) getVal(key string) (string, error) { return "", errors.New("rest: unable to get connection from redis pool") } -func (m *manager) fetchCachedInternalID(uid *userpb.UserId) (string, error) { - return m.getVal(userPrefix + userInternalIDPrefix + uid.OpaqueId) -} +func (m *manager) findCachedUsers(query string) ([]*userpb.User, error) { + conn := m.redisPool.Get() + defer conn.Close() + if conn != nil { + query = fmt.Sprintf("%s*%s*", userPrefix, strings.ReplaceAll(strings.ToLower(query), " ", "_")) + keys, err := redis.Strings(conn.Do("KEYS", query)) + if err != nil { + return nil, err + } + var args []interface{} + for _, k := range keys { + args = append(args, k) + } + + // Fetch the users for all these keys + userStrings, err := redis.Strings(conn.Do("MGET", args...)) + if err != nil { + return nil, err + } + userMap := make(map[string]*userpb.User) + for _, user := range userStrings { + u := userpb.User{} + if err = json.Unmarshal([]byte(user), &u); err == nil { + userMap[u.Id.OpaqueId] = &u + } + } + + var users []*userpb.User + for _, u := range userMap { + users = append(users, u) + } + + return users, nil + } -func (m *manager) cacheInternalID(uid *userpb.UserId, internalID string) error { - return m.setVal(userPrefix+userInternalIDPrefix+uid.OpaqueId, internalID, -1) + return nil, errors.New("rest: unable to get connection from redis pool") } func (m *manager) fetchCachedUserDetails(uid *userpb.UserId) (*userpb.User, error) { - user, err := m.getVal(userPrefix + uid.OpaqueId) + user, err := m.getVal(userPrefix + usernamePrefix + strings.ToLower(uid.OpaqueId)) if err != nil { return nil, err } @@ -120,25 +154,43 @@ func (m *manager) cacheUserDetails(u *userpb.User) error { if err != nil { return err } - if err = m.setVal(userPrefix+u.Id.OpaqueId, string(encodedUser), -1); err != nil { + if err = m.setVal(userPrefix+usernamePrefix+strings.ToLower(u.Id.OpaqueId), string(encodedUser), -1); err != nil { return err } - uid, err := extractUID(u) - if err == nil { - _ = m.setVal(userPrefix+"uid:"+uid, u.Id.OpaqueId, -1) + if u.Mail != "" { + if err = m.setVal(userPrefix+mailPrefix+strings.ToLower(u.Mail), string(encodedUser), -1); err != nil { + return err + } + } + if u.DisplayName != "" { + if err = m.setVal(userPrefix+namePrefix+u.Id.OpaqueId+"_"+strings.ReplaceAll(strings.ToLower(u.DisplayName), " ", "_"), string(encodedUser), -1); err != nil { + return err + } + } + if u.UidNumber != 0 { + if err = m.setVal(userPrefix+uidPrefix+strconv.FormatInt(u.UidNumber, 10), string(encodedUser), -1); err != nil { + return err + } } - _ = m.setVal(userPrefix+"mail:"+u.Mail, u.Id.OpaqueId, -1) - _ = m.setVal(userPrefix+"username:"+u.Username, u.Id.OpaqueId, -1) return nil } -func (m *manager) fetchCachedParam(field, claim string) (string, error) { - return m.getVal(userPrefix + field + ":" + claim) +func (m *manager) fetchCachedUserByParam(field, claim string) (*userpb.User, error) { + user, err := m.getVal(userPrefix + field + ":" + strings.ToLower(claim)) + if err != nil { + return nil, err + } + + u := userpb.User{} + if err = json.Unmarshal([]byte(user), &u); err != nil { + return nil, err + } + return &u, nil } func (m *manager) fetchCachedUserGroups(uid *userpb.UserId) ([]string, error) { - groups, err := m.getVal(userPrefix + userGroupsPrefix + uid.OpaqueId) + groups, err := m.getVal(userPrefix + userGroupsPrefix + strings.ToLower(uid.OpaqueId)) if err != nil { return nil, err } @@ -154,8 +206,5 @@ func (m *manager) cacheUserGroups(uid *userpb.UserId, groups []string) error { if err != nil { return err } - if err = m.setVal(userPrefix+userGroupsPrefix+uid.OpaqueId, string(g), m.conf.UserGroupsCacheExpiration*60); err != nil { - return err - } - return nil + return m.setVal(userPrefix+userGroupsPrefix+strings.ToLower(uid.OpaqueId), string(g), m.conf.UserGroupsCacheExpiration*60) } diff --git a/pkg/cbox/user/rest/rest.go b/pkg/cbox/user/rest/rest.go index 78ec0ffbbb..0c77c95c8a 100644 --- a/pkg/cbox/user/rest/rest.go +++ b/pkg/cbox/user/rest/rest.go @@ -21,10 +21,11 @@ package rest import ( "context" "fmt" - "net/url" - "regexp" - "strconv" + "os" + "os/signal" "strings" + "syscall" + "time" userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1" "github.com/cs3org/reva/pkg/appctx" @@ -34,17 +35,13 @@ import ( "github.com/gomodule/redigo/redis" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" + "github.com/rs/zerolog/log" ) func init() { registry.Register("rest", New) } -var ( - emailRegex = regexp.MustCompile(`^[\w-\.]+@([\w-]+\.)+[\w-]{2,4}$`) - usernameRegex = regexp.MustCompile(`^[ a-zA-Z0-9._-]+$`) -) - type manager struct { conf *config redisPool *redis.Pool @@ -63,7 +60,7 @@ type config struct { // The OIDC Provider IDProvider string `mapstructure:"id_provider" docs:"http://cernbox.cern.ch"` // Base API Endpoint - APIBaseURL string `mapstructure:"api_base_url" docs:"https://authorization-service-api-dev.web.cern.ch/api/v1.0"` + APIBaseURL string `mapstructure:"api_base_url" docs:"https://authorization-service-api-dev.web.cern.ch"` // Client ID needed to authenticate ClientID string `mapstructure:"client_id" docs:"-"` // Client Secret @@ -73,6 +70,8 @@ type config struct { OIDCTokenEndpoint string `mapstructure:"oidc_token_endpoint" docs:"https://keycloak-dev.cern.ch/auth/realms/cern/api-access/token"` // The target application for which token needs to be generated TargetAPI string `mapstructure:"target_api" docs:"authorization-service-api"` + // The time in seconds between bulk fetch of user accounts + UserFetchInterval int `mapstructure:"user_fetch_interval" docs:"3600"` } func (c *config) init() { @@ -83,7 +82,7 @@ func (c *config) init() { c.RedisAddress = ":6379" } if c.APIBaseURL == "" { - c.APIBaseURL = "https://authorization-service-api-dev.web.cern.ch/api/v1.0" + c.APIBaseURL = "https://authorization-service-api-dev.web.cern.ch" } if c.TargetAPI == "" { c.TargetAPI = "authorization-service-api" @@ -94,6 +93,9 @@ func (c *config) init() { if c.IDProvider == "" { c.IDProvider = "http://cernbox.cern.ch" } + if c.UserFetchInterval == 0 { + c.UserFetchInterval = 3600 + } } func parseConfig(m map[string]interface{}) (*config, error) { @@ -125,87 +127,73 @@ func (m *manager) Configure(ml map[string]interface{}) error { m.conf = c m.redisPool = redisPool m.apiTokenManager = apiTokenManager + + // Since we're starting a subroutine which would take some time to execute, + // we can't wait to see if it works before returning the user.Manager object + // TODO: return err if the fetch fails + go m.fetchAllUsers() return nil } -func (m *manager) getUser(ctx context.Context, url string) (map[string]interface{}, error) { - responseData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) - if err != nil { - return nil, err - } - - var users []map[string]interface{} - for _, usr := range responseData { - userData, ok := usr.(map[string]interface{}) - if !ok { - continue - } - - t, _ := userData["type"].(string) - userType := getUserType(t, userData["upn"].(string)) - if userType != userpb.UserType_USER_TYPE_APPLICATION { - users = append(users, userData) +func (m *manager) fetchAllUsers() { + _ = m.fetchAllUserAccounts() + ticker := time.NewTicker(time.Duration(m.conf.UserFetchInterval) * time.Second) + work := make(chan os.Signal, 1) + signal.Notify(work, syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT) + + for { + select { + case <-work: + return + case <-ticker.C: + _ = m.fetchAllUserAccounts() } } - - if len(users) != 1 { - return nil, errors.New("rest: user not found for URL: " + url) - } - - return users[0], nil -} - -func (m *manager) getUserByParam(ctx context.Context, param, val string) (map[string]interface{}, error) { - url := fmt.Sprintf("%s/Identity?filter=%s:%s&field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid&field=type", - m.conf.APIBaseURL, param, url.QueryEscape(val)) - return m.getUser(ctx, url) -} - -func (m *manager) getLightweightUser(ctx context.Context, mail string) (map[string]interface{}, error) { - url := fmt.Sprintf("%s/Identity?filter=primaryAccountEmail:%s&filter=upn:contains:guest&field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid&field=type", - m.conf.APIBaseURL, url.QueryEscape(mail)) - return m.getUser(ctx, url) } -func (m *manager) getInternalUserID(ctx context.Context, uid *userpb.UserId) (string, error) { +func (m *manager) fetchAllUserAccounts() error { + ctx := context.Background() + url := fmt.Sprintf("%s/api/v1.0/Identity?field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid&field=type", m.conf.APIBaseURL) - internalID, err := m.fetchCachedInternalID(uid) - if err != nil { - var ( - userData map[string]interface{} - err error - ) - if uid.Type == userpb.UserType_USER_TYPE_LIGHTWEIGHT { - // Lightweight accounts need to be fetched by email - userData, err = m.getLightweightUser(ctx, strings.TrimPrefix(uid.OpaqueId, "guest:")) - } else { - userData, err = m.getUserByParam(ctx, "upn", uid.OpaqueId) - } + for url != "" { + result, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) if err != nil { - return "", err + return err } - id, ok := userData["id"].(string) + + responseData, ok := result["data"].([]interface{}) if !ok { - return "", errors.New("rest: error in type assertion") + return errors.New("rest: error in type assertion") + } + for _, usr := range responseData { + userData, ok := usr.(map[string]interface{}) + if !ok { + continue + } + + _, err = m.parseAndCacheUser(ctx, userData) + if err != nil { + continue + } } - if err = m.cacheInternalID(uid, id); err != nil { - log := appctx.GetLogger(ctx) - log.Error().Err(err).Msg("rest: error caching user details") + url = "" + if pagination, ok := result["pagination"].(map[string]interface{}); ok { + if links, ok := pagination["links"].(map[string]interface{}); ok { + if next, ok := links["next"].(string); ok { + url = fmt.Sprintf("%s%s", m.conf.APIBaseURL, next) + } + } } - return id, nil } - return internalID, nil + + return nil } func (m *manager) parseAndCacheUser(ctx context.Context, userData map[string]interface{}) (*userpb.User, error) { - id, ok := userData["id"].(string) - if !ok { - return nil, errors.New("parseAndCacheUser: Missing id in userData") - } upn, ok := userData["upn"].(string) if !ok { - return nil, errors.New("parseAndCacheUser: Missing upn in userData") + return nil, errors.New("rest: missing upn in user data") } mail, _ := userData["primaryAccountEmail"].(string) name, _ := userData["displayName"].(string) @@ -229,36 +217,15 @@ func (m *manager) parseAndCacheUser(ctx context.Context, userData map[string]int } if err := m.cacheUserDetails(u); err != nil { - log := appctx.GetLogger(ctx) log.Error().Err(err).Msg("rest: error caching user details") } - if err := m.cacheInternalID(userID, id); err != nil { - log := appctx.GetLogger(ctx) - log.Error().Err(err).Msg("rest: error caching internal ID") - } return u, nil } func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId, skipFetchingGroups bool) (*userpb.User, error) { u, err := m.fetchCachedUserDetails(uid) if err != nil { - var ( - userData map[string]interface{} - err error - ) - if uid.Type == userpb.UserType_USER_TYPE_LIGHTWEIGHT { - // Lightweight accounts need to be fetched by email - userData, err = m.getLightweightUser(ctx, strings.TrimPrefix(uid.OpaqueId, "guest:")) - } else { - userData, err = m.getUserByParam(ctx, "upn", uid.OpaqueId) - } - if err != nil { - return nil, err - } - u, err = m.parseAndCacheUser(ctx, userData) - if err != nil { - return nil, err - } + return nil, err } if !skipFetchingGroups { @@ -273,34 +240,7 @@ func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId, skipFetchingG } func (m *manager) GetUserByClaim(ctx context.Context, claim, value string, skipFetchingGroups bool) (*userpb.User, error) { - opaqueID, err := m.fetchCachedParam(claim, value) - if err == nil { - return m.GetUser(ctx, &userpb.UserId{OpaqueId: opaqueID}, skipFetchingGroups) - } - - switch claim { - case "mail": - claim = "primaryAccountEmail" - case "uid": - claim = "uid" - case "username": - claim = "upn" - default: - return nil, errors.New("rest: invalid field: " + claim) - } - - var userData map[string]interface{} - if claim == "upn" && strings.HasPrefix(value, "guest:") { - // Lightweight accounts need to be fetched by email, regardless of the demanded claim - userData, err = m.getLightweightUser(ctx, strings.TrimPrefix(value, "guest:")) - } else { - userData, err = m.getUserByParam(ctx, claim, value) - } - - if err != nil { - return nil, err - } - u, err := m.parseAndCacheUser(ctx, userData) + u, err := m.fetchCachedUserByParam(claim, value) if err != nil { return nil, err } @@ -316,61 +256,6 @@ func (m *manager) GetUserByClaim(ctx context.Context, claim, value string, skipF return u, nil } -func (m *manager) findUsersByFilter(ctx context.Context, url string, users map[string]*userpb.User, skipFetchingGroups bool) error { - - userData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) - if err != nil { - return err - } - - for _, usr := range userData { - usrInfo, ok := usr.(map[string]interface{}) - if !ok { - continue - } - - upn, ok := usrInfo["upn"].(string) - if !ok { - continue - } - mail, _ := usrInfo["primaryAccountEmail"].(string) - name, _ := usrInfo["displayName"].(string) - uidNumber, _ := usrInfo["uid"].(float64) - gidNumber, _ := usrInfo["gid"].(float64) - t, _ := usrInfo["type"].(string) - userType := getUserType(t, upn) - - if userType == userpb.UserType_USER_TYPE_APPLICATION { - continue - } - - uid := &userpb.UserId{ - OpaqueId: upn, - Idp: m.conf.IDProvider, - Type: userType, - } - var userGroups []string - if !skipFetchingGroups { - userGroups, err = m.GetUserGroups(ctx, uid) - if err != nil { - return err - } - } - - users[uid.OpaqueId] = &userpb.User{ - Id: uid, - Username: upn, - Mail: mail, - DisplayName: name, - UidNumber: int64(uidNumber), - GidNumber: int64(gidNumber), - Groups: userGroups, - } - } - - return nil -} - func (m *manager) FindUsers(ctx context.Context, query string, skipFetchingGroups bool) ([]*userpb.User, error) { // Look at namespaces filters. If the query starts with: @@ -386,25 +271,9 @@ func (m *manager) FindUsers(ctx context.Context, query string, skipFetchingGroup namespace, query = parts[0], parts[1] } - var filters []string - switch { - case usernameRegex.MatchString(query): - filters = []string{"upn", "displayName", "primaryAccountEmail"} - case emailRegex.MatchString(query): - filters = []string{"primaryAccountEmail"} - default: - return nil, errors.New("rest: illegal characters present in query: " + query) - } - - users := make(map[string]*userpb.User) - - for _, f := range filters { - url := fmt.Sprintf("%s/Identity/?filter=%s:contains:%s&field=id&field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid&field=type", - m.conf.APIBaseURL, f, url.QueryEscape(query)) - err := m.findUsersByFilter(ctx, url, users, skipFetchingGroups) - if err != nil { - return nil, err - } + users, err := m.findCachedUsers(query) + if err != nil { + return nil, err } userSlice := []*userpb.User{} @@ -439,22 +308,18 @@ func isUserAnyType(user *userpb.User, types []userpb.UserType) bool { } func (m *manager) GetUserGroups(ctx context.Context, uid *userpb.UserId) ([]string, error) { - groups, err := m.fetchCachedUserGroups(uid) if err == nil { return groups, nil } - internalID, err := m.getInternalUserID(ctx, uid) - if err != nil { - return nil, err - } - url := fmt.Sprintf("%s/Identity/%s/groups?recursive=true", m.conf.APIBaseURL, internalID) - groupData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) + url := fmt.Sprintf("%s/api/v1.0/Identity/%s/groups?recursive=true", m.conf.APIBaseURL, uid.OpaqueId) + result, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) if err != nil { return nil, err } + groupData := result["data"].([]interface{}) groups = []string{} for _, g := range groupData { @@ -490,13 +355,6 @@ func (m *manager) IsInGroup(ctx context.Context, uid *userpb.UserId, group strin return false, nil } -func extractUID(u *userpb.User) (string, error) { - if u.UidNumber == 0 { - return "", errors.New("rest: could not retrieve UID from user") - } - return strconv.FormatInt(u.UidNumber, 10), nil -} - func getUserType(userType, upn string) userpb.UserType { var t userpb.UserType switch userType { diff --git a/pkg/cbox/utils/tokenmanagement.go b/pkg/cbox/utils/tokenmanagement.go index e16fd503ba..c8978f4929 100644 --- a/pkg/cbox/utils/tokenmanagement.go +++ b/pkg/cbox/utils/tokenmanagement.go @@ -127,7 +127,7 @@ func (a *APITokenManager) getAPIToken(ctx context.Context) (string, time.Time, e } // SendAPIGetRequest makes an API GET Request to the passed URL -func (a *APITokenManager) SendAPIGetRequest(ctx context.Context, url string, forceRenewal bool) ([]interface{}, error) { +func (a *APITokenManager) SendAPIGetRequest(ctx context.Context, url string, forceRenewal bool) (map[string]interface{}, error) { err := a.renewAPIToken(ctx, forceRenewal) if err != nil { return nil, err @@ -168,10 +168,5 @@ func (a *APITokenManager) SendAPIGetRequest(ctx context.Context, url string, for return nil, err } - responseData, ok := result["data"].([]interface{}) - if !ok { - return nil, errors.New("rest: error in type assertion") - } - - return responseData, nil + return result, nil }