diff --git a/pkg/sources/github/connector_token.go b/pkg/sources/github/connector_token.go index 88ef8f8f1ec0..469f50b98c2f 100644 --- a/pkg/sources/github/connector_token.go +++ b/pkg/sources/github/connector_token.go @@ -17,14 +17,14 @@ type tokenConnector struct { apiClient *github.Client token string isGitHubEnterprise bool - handleRateLimit func(context.Context, error) bool + handleRateLimit func(context.Context, error, ...errorReporter) bool user string userMu sync.Mutex } var _ connector = (*tokenConnector)(nil) -func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error) bool) (*tokenConnector, error) { +func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error, ...errorReporter) bool) (*tokenConnector, error) { const httpTimeoutSeconds = 60 httpClient := common.RetryableHTTPClientTimeout(int64(httpTimeoutSeconds)) tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 4dce437fba47..d3aca1ccbd90 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -381,7 +381,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e // See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788 for _, name := range s.filteredRepoCache.Keys() { url, _ := s.filteredRepoCache.Get(name) - url, err := s.ensureRepoInfoCache(ctx, url) + url, err := s.ensureRepoInfoCache(ctx, url, &unitErrorReporter{reporter}) if err != nil { if err := dedupeReporter.UnitErr(ctx, err); err != nil { return err @@ -417,9 +417,12 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e for _, repo := range s.filteredRepoCache.Values() { ctx := context.WithValue(ctx, "repo", repo) - repo, err := s.ensureRepoInfoCache(ctx, repo) + repo, err := s.ensureRepoInfoCache(ctx, repo, &unitErrorReporter{reporter}) if err != nil { ctx.Logger().Error(err, "error caching repo info") + if err := dedupeReporter.UnitErr(ctx, fmt.Errorf("error caching repo info: %w", err)); err != nil { + ctx.Logger().Error(err, "failed to report unit error") + } } s.repos = append(s.repos, repo) } @@ -434,7 +437,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e // provided repository URL. If not, it fetches and stores the metadata for the // repository. In some cases, the gist URL needs to be normalized, which is // returned by this function. -func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string) (string, error) { +func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string, reporter errorReporter) (string, error) { if _, ok := s.repoInfoCache.get(repo); ok { return repo, nil } @@ -450,15 +453,15 @@ func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string) (string, for { gistID := extractGistID(urlParts) gist, _, err := s.connector.APIClient().Gists.Get(ctx, gistID) - // Normalize the URL to the Gist's pull URL. - // See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937 - repo = gist.GetGitPullURL() - if s.handleRateLimit(ctx, err) { + if s.handleRateLimit(ctx, err, reporter) { continue } if err != nil { return repo, fmt.Errorf("failed to fetch gist") } + // Normalize the URL to the Gist's pull URL. + // See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937 + repo = gist.GetGitPullURL() s.cacheGistInfo(gist) break } @@ -466,7 +469,7 @@ func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string) (string, // Cache repository info. for { ghRepo, _, err := s.connector.APIClient().Repositories.Get(ctx, urlParts[1], urlParts[2]) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimit(ctx, err, reporter) { continue } if err != nil { @@ -491,7 +494,7 @@ func (s *Source) enumerateBasicAuth(ctx context.Context, reporter sources.UnitRe // TODO: This modifies s.memberCache but it doesn't look like // we do anything with it. if userType == organization && s.conn.ScanUsers { - if err := s.addMembersByOrg(ctx, org); err != nil { + if err := s.addMembersByOrg(ctx, org, reporter); err != nil { orgCtx.Logger().Error(err, "Unable to add members by org") } } @@ -526,7 +529,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool var err error for { ghUser, _, err = s.connector.APIClient().Users.Get(ctx, "") - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithUnitReporter(ctx, reporter, err) { continue } if err != nil { @@ -546,11 +549,11 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool } if isGithubEnterprise { - s.addAllVisibleOrgs(ctx) + s.addAllVisibleOrgs(ctx, reporter) } else { // Scan for orgs is default with a token. // GitHub App enumerates the repos that were assigned to it in GitHub App settings. - s.addOrgsByUser(ctx, ghUser.GetLogin()) + s.addOrgsByUser(ctx, ghUser.GetLogin(), reporter) } } @@ -564,7 +567,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool } if userType == organization && s.conn.ScanUsers { - if err := s.addMembersByOrg(ctx, org); err != nil { + if err := s.addMembersByOrg(ctx, org, reporter); err != nil { orgCtx.Logger().Error(err, "Unable to add members for org") } } @@ -588,7 +591,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu // Check if we need to find user repos. if s.conn.ScanUsers { - err := s.addMembersByApp(ctx, installationClient) + err := s.addMembersByApp(ctx, installationClient, reporter) if err != nil { return err } @@ -739,13 +742,37 @@ var ( rateLimitResumeTime time.Time ) -// handleRateLimit returns true if a rate limit was handled +// errorReporter is an interface that captures just the error reporting functionality +type errorReporter interface { + Err(ctx context.Context, err error) error +} + +// wrapper to adapt UnitReporter to errorReporter +type unitErrorReporter struct { + reporter sources.UnitReporter +} + +func (u unitErrorReporter) Err(ctx context.Context, err error) error { + return u.reporter.UnitErr(ctx, err) +} + +// wrapper to adapt ChunkReporter to errorReporter +type chunkErrorReporter struct { + reporter sources.ChunkReporter +} + +func (c chunkErrorReporter) Err(ctx context.Context, err error) error { + return c.reporter.ChunkErr(ctx, err) +} + +// handleRateLimit handles GitHub API rate limiting with an optional error reporter. +// Returns true if a rate limit was handled. // // Unauthenticated users have a rate limit of 60 requests per hour. // Authenticated users have a rate limit of 5,000 requests per hour, // however, certain actions are subject to a stricter "secondary" limit. // https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api -func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool { +func (s *Source) handleRateLimit(ctx context.Context, errIn error, reporter ...errorReporter) bool { if errIn == nil { return false } @@ -757,7 +784,6 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool { var retryAfter time.Duration if resumeTime.IsZero() || time.Now().After(resumeTime) { rateLimitMu.Lock() - var ( now = time.Now() @@ -785,6 +811,12 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool { retryAfter = retryAfter + jitter rateLimitResumeTime = now.Add(retryAfter) ctx.Logger().Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339)) + // Only report the error if a reporter was provided + if len(reporter) > 0 { + if err := reporter[0].Err(ctx, fmt.Errorf("exceeded %s rate limit", limitType)); err != nil { + ctx.Logger().Error(err, "failed to report rate limit error") + } + } } else { retryAfter = (5 * time.Minute) + jitter rateLimitResumeTime = now.Add(retryAfter) @@ -803,6 +835,16 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool { return true } +// handleRateLimitWithUnitReporter is a wrapper around handleRateLimit that includes unit reporting +func (s *Source) handleRateLimitWithUnitReporter(ctx context.Context, reporter sources.UnitReporter, errIn error) bool { + return s.handleRateLimit(ctx, errIn, &unitErrorReporter{reporter: reporter}) +} + +// handleRateLimitWithChunkReporter is a wrapper around handleRateLimit that includes chunk reporting +func (s *Source) handleRateLimitWithChunkReporter(ctx context.Context, reporter sources.ChunkReporter, errIn error) bool { + return s.handleRateLimit(ctx, errIn, &chunkErrorReporter{reporter: reporter}) +} + func (s *Source) addReposForMembers(ctx context.Context, reporter sources.UnitReporter) { ctx.Logger().Info("Fetching repos from members", "members", len(s.memberCache)) for member := range s.memberCache { @@ -823,7 +865,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter for { gists, res, err := s.connector.APIClient().Gists.List(ctx, user, gistOpts) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithUnitReporter(ctx, reporter, err) { continue } if err != nil { @@ -847,7 +889,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter return nil } -func (s *Source) addMembersByApp(ctx context.Context, installationClient *github.Client) error { +func (s *Source) addMembersByApp(ctx context.Context, installationClient *github.Client, reporter sources.UnitReporter) error { opts := &github.ListOptions{ PerPage: membersAppPagination, } @@ -862,7 +904,7 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github if org.Account.GetType() != "Organization" { continue } - if err := s.addMembersByOrg(ctx, *org.Account.Login); err != nil { + if err := s.addMembersByOrg(ctx, *org.Account.Login, reporter); err != nil { return err } } @@ -870,7 +912,7 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github return nil } -func (s *Source) addAllVisibleOrgs(ctx context.Context) { +func (s *Source) addAllVisibleOrgs(ctx context.Context, reporter sources.UnitReporter) { ctx.Logger().V(2).Info("enumerating all visible organizations on GHE") // Enumeration on this endpoint does not use pages it uses a since ID. // The endpoint will return organizations with an ID greater than the given since ID. @@ -883,7 +925,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) { } for { orgs, _, err := s.connector.APIClient().Organizations.ListAll(ctx, orgOpts) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithUnitReporter(ctx, reporter, err) { continue } if err != nil { @@ -915,14 +957,14 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) { } } -func (s *Source) addOrgsByUser(ctx context.Context, user string) { +func (s *Source) addOrgsByUser(ctx context.Context, user string, reporter sources.UnitReporter) { orgOpts := &github.ListOptions{ PerPage: defaultPagination, } logger := ctx.Logger().WithValues("user", user) for { orgs, resp, err := s.connector.APIClient().Organizations.List(ctx, "", orgOpts) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithUnitReporter(ctx, reporter, err) { continue } if err != nil { @@ -944,7 +986,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) { } } -func (s *Source) addMembersByOrg(ctx context.Context, org string) error { +func (s *Source) addMembersByOrg(ctx context.Context, org string, reporter sources.UnitReporter) error { opts := &github.ListMembersOptions{ PublicOnly: false, ListOptions: github.ListOptions{ @@ -955,7 +997,7 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error { logger := ctx.Logger().WithValues("org", org) for { members, res, err := s.connector.APIClient().Organizations.ListMembers(ctx, org, opts) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithUnitReporter(ctx, reporter, err) { continue } if err != nil { @@ -1087,7 +1129,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar } for { comments, _, err := s.connector.APIClient().Gists.ListComments(ctx, gistID, options) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithChunkReporter(ctx, reporter, err) { continue } if err != nil { @@ -1187,7 +1229,6 @@ func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, rep } return nil - } func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error { @@ -1203,7 +1244,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter for { issues, _, err := s.connector.APIClient().Issues.ListByRepo(ctx, repoInfo.owner, repoInfo.name, bodyTextsOpts) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithChunkReporter(ctx, reporter, err) { continue } @@ -1272,7 +1313,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, re for { issueComments, _, err := s.connector.APIClient().Issues.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, issueOpts) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithChunkReporter(ctx, reporter, err) { continue } if err != nil { @@ -1340,7 +1381,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, reporter sou for { prs, _, err := s.connector.APIClient().PullRequests.List(ctx, repoInfo.owner, repoInfo.name, prOpts) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithChunkReporter(ctx, reporter, err) { continue } if err != nil { @@ -1372,7 +1413,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, repor for { prComments, _, err := s.connector.APIClient().PullRequests.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, prOpts) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithChunkReporter(ctx, reporter, err) { continue } if err != nil { @@ -1528,7 +1569,7 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporte ctx = context.WithValue(ctx, "repo", repoURL) // ChunkUnit is not guaranteed to be called from Enumerate, so we must // check and fetch the repoInfoCache for this repo. - repoURL, err := s.ensureRepoInfoCache(ctx, repoURL) + repoURL, err := s.ensureRepoInfoCache(ctx, repoURL, &chunkErrorReporter{reporter: reporter}) if err != nil { return err } diff --git a/pkg/sources/github/repo.go b/pkg/sources/github/repo.go index c5e813578e7c..660dfd95551e 100644 --- a/pkg/sources/github/repo.go +++ b/pkg/sources/github/repo.go @@ -154,6 +154,9 @@ func (s *Source) getReposByOrgOrUser(ctx context.Context, name string, reporter if err == nil { return organization, nil } else if !isGitHub404Error(err) { + if err := reporter.UnitErr(ctx, fmt.Errorf("error getting repos by org: %w", err)); err != nil { + return unknown, err + } return unknown, err } @@ -181,6 +184,7 @@ func isGitHub404Error(err error) bool { return ghErr.Response.StatusCode == http.StatusNotFound } +// processRepos is the main function for getting repositories from a source. func (s *Source) processRepos(ctx context.Context, target string, reporter sources.UnitReporter, listRepos repoLister, listOpts repoListOptions) error { logger := ctx.Logger().WithValues("target", target) opts := listOpts.getListOptions() @@ -192,7 +196,7 @@ func (s *Source) processRepos(ctx context.Context, target string, reporter sourc for { someRepos, res, err := listRepos(ctx, target, listOpts) - if s.handleRateLimit(ctx, err) { + if s.handleRateLimitWithUnitReporter(ctx, reporter, err) { continue } if err != nil { diff --git a/pkg/sources/source_manager.go b/pkg/sources/source_manager.go index 43784141f4a8..2e03be1d8127 100644 --- a/pkg/sources/source_manager.go +++ b/pkg/sources/source_manager.go @@ -406,8 +406,7 @@ func (s *SourceManager) enumerate(ctx context.Context, source Source, report *Jo // Check if source units are supported and configured. canUseSourceUnits := s.useSourceUnitsFunc != nil if enumChunker, ok := source.(SourceUnitEnumerator); ok && canUseSourceUnits && s.useSourceUnitsFunc() { - ctx.Logger().Info("running source", - "with_units", true) + ctx.Logger().Info("running source", "with_units", true) return s.enumerateWithUnits(ctx, enumChunker, report, reporter) } return fmt.Errorf("Enumeration not supported or configured for source: %s", source.Type().String())