Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SCAN-165] Use Err Reporting #3862

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/sources/github/connector_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ type tokenConnector struct {
apiClient *github.Client
token string
isGitHubEnterprise bool
handleRateLimit func(context.Context, error) bool
handleRateLimit func(context.Context, error, ...errorReporter) bool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is errorReporter a vararg to allow for omission at the callsite? Are there places that we don't want to report the rate limit error?

user string
userMu sync.Mutex
}

var _ connector = (*tokenConnector)(nil)

func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error) bool) (*tokenConnector, error) {
func newTokenConnector(apiEndpoint string, token string, handleRateLimit func(context.Context, error, ...errorReporter) bool) (*tokenConnector, error) {
const httpTimeoutSeconds = 60
httpClient := common.RetryableHTTPClientTimeout(int64(httpTimeoutSeconds))
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
Expand Down
107 changes: 74 additions & 33 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
// See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
for _, name := range s.filteredRepoCache.Keys() {
url, _ := s.filteredRepoCache.Get(name)
url, err := s.ensureRepoInfoCache(ctx, url)
url, err := s.ensureRepoInfoCache(ctx, url, &unitErrorReporter{reporter})
if err != nil {
if err := dedupeReporter.UnitErr(ctx, err); err != nil {
return err
Expand Down Expand Up @@ -417,9 +417,12 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
for _, repo := range s.filteredRepoCache.Values() {
ctx := context.WithValue(ctx, "repo", repo)

repo, err := s.ensureRepoInfoCache(ctx, repo)
repo, err := s.ensureRepoInfoCache(ctx, repo, &unitErrorReporter{reporter})
if err != nil {
ctx.Logger().Error(err, "error caching repo info")
if err := dedupeReporter.UnitErr(ctx, fmt.Errorf("error caching repo info: %w", err)); err != nil {
ctx.Logger().Error(err, "failed to report unit error")
}
Comment on lines +423 to +425
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An error returned by UnitErr does not mean it failed to record the error. Instead, it's a signal to stop processing and return. That was the intention, anyway, and in practice I think it would really only happen on a context cancellation, and sources should be context-aware anyway.

I'm thinking of removing the error return from those interfaces as it's kind of confusing and not really providing much value.

}
s.repos = append(s.repos, repo)
}
Expand All @@ -434,7 +437,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
// provided repository URL. If not, it fetches and stores the metadata for the
// repository. In some cases, the gist URL needs to be normalized, which is
// returned by this function.
func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string) (string, error) {
func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string, reporter errorReporter) (string, error) {
if _, ok := s.repoInfoCache.get(repo); ok {
return repo, nil
}
Expand All @@ -450,23 +453,23 @@ func (s *Source) ensureRepoInfoCache(ctx context.Context, repo string) (string,
for {
gistID := extractGistID(urlParts)
gist, _, err := s.connector.APIClient().Gists.Get(ctx, gistID)
// Normalize the URL to the Gist's pull URL.
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
repo = gist.GetGitPullURL()
if s.handleRateLimit(ctx, err) {
if s.handleRateLimit(ctx, err, reporter) {
continue
}
if err != nil {
return repo, fmt.Errorf("failed to fetch gist")
}
// Normalize the URL to the Gist's pull URL.
// See https://github.com/trufflesecurity/trufflehog/pull/2625#issuecomment-2025507937
repo = gist.GetGitPullURL()
Comment on lines +462 to +464
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this added here? It looks unrelated to error reporting.

s.cacheGistInfo(gist)
break
}
} else {
// Cache repository info.
for {
ghRepo, _, err := s.connector.APIClient().Repositories.Get(ctx, urlParts[1], urlParts[2])
if s.handleRateLimit(ctx, err) {
if s.handleRateLimit(ctx, err, reporter) {
continue
}
if err != nil {
Expand All @@ -491,7 +494,7 @@ func (s *Source) enumerateBasicAuth(ctx context.Context, reporter sources.UnitRe
// TODO: This modifies s.memberCache but it doesn't look like
// we do anything with it.
if userType == organization && s.conn.ScanUsers {
if err := s.addMembersByOrg(ctx, org); err != nil {
if err := s.addMembersByOrg(ctx, org, reporter); err != nil {
orgCtx.Logger().Error(err, "Unable to add members by org")
}
}
Expand Down Expand Up @@ -526,7 +529,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
var err error
for {
ghUser, _, err = s.connector.APIClient().Users.Get(ctx, "")
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand All @@ -546,11 +549,11 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
}

if isGithubEnterprise {
s.addAllVisibleOrgs(ctx)
s.addAllVisibleOrgs(ctx, reporter)
} else {
// Scan for orgs is default with a token.
// GitHub App enumerates the repos that were assigned to it in GitHub App settings.
s.addOrgsByUser(ctx, ghUser.GetLogin())
s.addOrgsByUser(ctx, ghUser.GetLogin(), reporter)
}
}

Expand All @@ -564,7 +567,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
}

if userType == organization && s.conn.ScanUsers {
if err := s.addMembersByOrg(ctx, org); err != nil {
if err := s.addMembersByOrg(ctx, org, reporter); err != nil {
orgCtx.Logger().Error(err, "Unable to add members for org")
}
}
Expand All @@ -588,7 +591,7 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu

// Check if we need to find user repos.
if s.conn.ScanUsers {
err := s.addMembersByApp(ctx, installationClient)
err := s.addMembersByApp(ctx, installationClient, reporter)
if err != nil {
return err
}
Expand Down Expand Up @@ -739,13 +742,37 @@ var (
rateLimitResumeTime time.Time
)

// handleRateLimit returns true if a rate limit was handled
// errorReporter is the minimal error-reporting interface accepted by
// handleRateLimit and ensureRepoInfoCache. It exists so that a
// sources.UnitReporter or a sources.ChunkReporter can be passed through the
// same parameter, via the unitErrorReporter and chunkErrorReporter adapters.
type errorReporter interface {
// Err reports err. The implementations in this file forward it to
// UnitErr or ChunkErr and return that call's result.
Err(ctx context.Context, err error) error
}

// unitErrorReporter adapts a sources.UnitReporter to the errorReporter
// interface by mapping Err onto UnitErr.
type unitErrorReporter struct {
reporter sources.UnitReporter
}

// Err forwards err to the wrapped reporter's UnitErr and returns its result.
// The receiver is a value, so both unitErrorReporter and *unitErrorReporter
// satisfy errorReporter; call sites in this file pass a pointer.
func (u unitErrorReporter) Err(ctx context.Context, err error) error {
return u.reporter.UnitErr(ctx, err)
}

// chunkErrorReporter adapts a sources.ChunkReporter to the errorReporter
// interface by mapping Err onto ChunkErr.
type chunkErrorReporter struct {
reporter sources.ChunkReporter
}

// Err forwards err to the wrapped reporter's ChunkErr and returns its result.
// The receiver is a value, so both chunkErrorReporter and *chunkErrorReporter
// satisfy errorReporter; call sites in this file pass a pointer.
func (c chunkErrorReporter) Err(ctx context.Context, err error) error {
return c.reporter.ChunkErr(ctx, err)
}

// handleRateLimit handles GitHub API rate limiting with an optional error reporter.
// Returns true if a rate limit was handled.
//
// Unauthenticated users have a rate limit of 60 requests per hour.
// Authenticated users have a rate limit of 5,000 requests per hour,
// however, certain actions are subject to a stricter "secondary" limit.
// https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api
func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
func (s *Source) handleRateLimit(ctx context.Context, errIn error, reporter ...errorReporter) bool {
if errIn == nil {
return false
}
Expand All @@ -757,7 +784,6 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
var retryAfter time.Duration
if resumeTime.IsZero() || time.Now().After(resumeTime) {
rateLimitMu.Lock()

var (
now = time.Now()

Expand Down Expand Up @@ -785,6 +811,12 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
retryAfter = retryAfter + jitter
rateLimitResumeTime = now.Add(retryAfter)
ctx.Logger().Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339))
// Only report the error if a reporter was provided
if len(reporter) > 0 {
if err := reporter[0].Err(ctx, fmt.Errorf("exceeded %s rate limit", limitType)); err != nil {
ctx.Logger().Error(err, "failed to report rate limit error")
}
}
Comment on lines +814 to +819
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to go with the vararg approach, I'd suggest iterating over all the reporters so at least valid calls with multiple reporters aren't silently ignored. It has the added bonus of not having to check the length of the input too.

for _, reporter := range reporters {
    if err := reporter.Err(ctx, ...)
}

} else {
retryAfter = (5 * time.Minute) + jitter
rateLimitResumeTime = now.Add(retryAfter)
Expand All @@ -803,6 +835,16 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
return true
}

// handleRateLimitWithUnitReporter wraps reporter in a unitErrorReporter and
// delegates to handleRateLimit, so that handleRateLimit can report through
// reporter.UnitErr. The boolean returned by handleRateLimit is passed back
// unchanged.
func (s *Source) handleRateLimitWithUnitReporter(ctx context.Context, reporter sources.UnitReporter, errIn error) bool {
	adapter := &unitErrorReporter{reporter: reporter}
	handled := s.handleRateLimit(ctx, errIn, adapter)
	return handled
}

// handleRateLimitWithChunkReporter wraps reporter in a chunkErrorReporter and
// delegates to handleRateLimit, so that handleRateLimit can report through
// reporter.ChunkErr. The boolean returned by handleRateLimit is passed back
// unchanged.
func (s *Source) handleRateLimitWithChunkReporter(ctx context.Context, reporter sources.ChunkReporter, errIn error) bool {
	adapter := &chunkErrorReporter{reporter: reporter}
	handled := s.handleRateLimit(ctx, errIn, adapter)
	return handled
}

func (s *Source) addReposForMembers(ctx context.Context, reporter sources.UnitReporter) {
ctx.Logger().Info("Fetching repos from members", "members", len(s.memberCache))
for member := range s.memberCache {
Expand All @@ -823,7 +865,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter

for {
gists, res, err := s.connector.APIClient().Gists.List(ctx, user, gistOpts)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand All @@ -847,7 +889,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter
return nil
}

func (s *Source) addMembersByApp(ctx context.Context, installationClient *github.Client) error {
func (s *Source) addMembersByApp(ctx context.Context, installationClient *github.Client, reporter sources.UnitReporter) error {
opts := &github.ListOptions{
PerPage: membersAppPagination,
}
Expand All @@ -862,15 +904,15 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github
if org.Account.GetType() != "Organization" {
continue
}
if err := s.addMembersByOrg(ctx, *org.Account.Login); err != nil {
if err := s.addMembersByOrg(ctx, *org.Account.Login, reporter); err != nil {
return err
}
}

return nil
}

func (s *Source) addAllVisibleOrgs(ctx context.Context) {
func (s *Source) addAllVisibleOrgs(ctx context.Context, reporter sources.UnitReporter) {
ctx.Logger().V(2).Info("enumerating all visible organizations on GHE")
// Enumeration on this endpoint does not use pages; it uses a since ID.
// The endpoint will return organizations with an ID greater than the given since ID.
Expand All @@ -883,7 +925,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
}
for {
orgs, _, err := s.connector.APIClient().Organizations.ListAll(ctx, orgOpts)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -915,14 +957,14 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
}
}

func (s *Source) addOrgsByUser(ctx context.Context, user string) {
func (s *Source) addOrgsByUser(ctx context.Context, user string, reporter sources.UnitReporter) {
orgOpts := &github.ListOptions{
PerPage: defaultPagination,
}
logger := ctx.Logger().WithValues("user", user)
for {
orgs, resp, err := s.connector.APIClient().Organizations.List(ctx, "", orgOpts)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand All @@ -944,7 +986,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
}
}

func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
func (s *Source) addMembersByOrg(ctx context.Context, org string, reporter sources.UnitReporter) error {
opts := &github.ListMembersOptions{
PublicOnly: false,
ListOptions: github.ListOptions{
Expand All @@ -955,7 +997,7 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
logger := ctx.Logger().WithValues("org", org)
for {
members, res, err := s.connector.APIClient().Organizations.ListMembers(ctx, org, opts)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -1087,7 +1129,7 @@ func (s *Source) processGistComments(ctx context.Context, gistURL string, urlPar
}
for {
comments, _, err := s.connector.APIClient().Gists.ListComments(ctx, gistID, options)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -1187,7 +1229,6 @@ func (s *Source) processRepoComments(ctx context.Context, repoInfo repoInfo, rep
}

return nil

}

func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter sources.ChunkReporter) error {
Expand All @@ -1203,7 +1244,7 @@ func (s *Source) processIssues(ctx context.Context, repoInfo repoInfo, reporter

for {
issues, _, err := s.connector.APIClient().Issues.ListByRepo(ctx, repoInfo.owner, repoInfo.name, bodyTextsOpts)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
continue
}

Expand Down Expand Up @@ -1272,7 +1313,7 @@ func (s *Source) processIssueComments(ctx context.Context, repoInfo repoInfo, re

for {
issueComments, _, err := s.connector.APIClient().Issues.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, issueOpts)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -1340,7 +1381,7 @@ func (s *Source) processPRs(ctx context.Context, repoInfo repoInfo, reporter sou

for {
prs, _, err := s.connector.APIClient().PullRequests.List(ctx, repoInfo.owner, repoInfo.name, prOpts)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -1372,7 +1413,7 @@ func (s *Source) processPRComments(ctx context.Context, repoInfo repoInfo, repor

for {
prComments, _, err := s.connector.APIClient().PullRequests.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, prOpts)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithChunkReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand Down Expand Up @@ -1528,7 +1569,7 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporte
ctx = context.WithValue(ctx, "repo", repoURL)
// ChunkUnit is not guaranteed to be called from Enumerate, so we must
// check and fetch the repoInfoCache for this repo.
repoURL, err := s.ensureRepoInfoCache(ctx, repoURL)
repoURL, err := s.ensureRepoInfoCache(ctx, repoURL, &chunkErrorReporter{reporter: reporter})
if err != nil {
return err
}
Expand Down
6 changes: 5 additions & 1 deletion pkg/sources/github/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ func (s *Source) getReposByOrgOrUser(ctx context.Context, name string, reporter
if err == nil {
return organization, nil
} else if !isGitHub404Error(err) {
if err := reporter.UnitErr(ctx, fmt.Errorf("error getting repos by org: %w", err)); err != nil {
return unknown, err
}
return unknown, err
}

Expand Down Expand Up @@ -181,6 +184,7 @@ func isGitHub404Error(err error) bool {
return ghErr.Response.StatusCode == http.StatusNotFound
}

// processRepos is the main function for getting repositories from a source.
func (s *Source) processRepos(ctx context.Context, target string, reporter sources.UnitReporter, listRepos repoLister, listOpts repoListOptions) error {
logger := ctx.Logger().WithValues("target", target)
opts := listOpts.getListOptions()
Expand All @@ -192,7 +196,7 @@ func (s *Source) processRepos(ctx context.Context, target string, reporter sourc

for {
someRepos, res, err := listRepos(ctx, target, listOpts)
if s.handleRateLimit(ctx, err) {
if s.handleRateLimitWithUnitReporter(ctx, reporter, err) {
continue
}
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions pkg/sources/source_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,7 @@ func (s *SourceManager) enumerate(ctx context.Context, source Source, report *Jo
// Check if source units are supported and configured.
canUseSourceUnits := s.useSourceUnitsFunc != nil
if enumChunker, ok := source.(SourceUnitEnumerator); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
ctx.Logger().Info("running source",
"with_units", true)
ctx.Logger().Info("running source", "with_units", true)
return s.enumerateWithUnits(ctx, enumChunker, report, reporter)
}
return fmt.Errorf("Enumeration not supported or configured for source: %s", source.Type().String())
Expand Down
Loading