diff --git a/.golangci.yml b/.golangci.yml index 4b9ad0d385..da78e94a0c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -12,6 +12,9 @@ issues: - linters: [ staticcheck ] path: ".*_test\\.go$" text: "Subjects has been deprecated since Go 1\\.18.*Subjects will not include the system roots" + - linters: [ gosec ] + path: "(.*_test\\.go$)|(^test/.*)" + text: "integer overflow conversion" linters: enable: @@ -45,6 +48,3 @@ linters-settings: rules: - name: unused-parameter disabled: true - gosec: - excludes: - - G115 # "Potential integer overflow when converting between integer types"; TODO re-enable eventually diff --git a/cmd/spire-server/cli/entry/create.go b/cmd/spire-server/cli/entry/create.go index 2ef733b075..62caede981 100644 --- a/cmd/spire-server/cli/entry/create.go +++ b/cmd/spire-server/cli/entry/create.go @@ -4,7 +4,9 @@ import ( "context" "errors" "flag" + "fmt" + "github.com/ccoveille/go-safecast" "github.com/mitchellh/cli" entryv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/entry/v1" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" @@ -175,6 +177,16 @@ func (c *createCommand) parseConfig() ([]*types.Entry, error) { return nil, err } + x509SvidTTL, err := safecast.ToInt32(c.x509SVIDTTL) + if err != nil { + return nil, fmt.Errorf("X509 SVID TTL: %w", err) + } + + jwtSvidTTL, err := safecast.ToInt32(c.jwtSVIDTTL) + if err != nil { + return nil, fmt.Errorf("JWT SVID TTL: %w", err) + } + e := &types.Entry{ Id: c.entryID, ParentId: parentID, @@ -183,8 +195,8 @@ func (c *createCommand) parseConfig() ([]*types.Entry, error) { ExpiresAt: c.entryExpiry, DnsNames: c.dnsNames, StoreSvid: c.storeSVID, - X509SvidTtl: int32(c.x509SVIDTTL), - JwtSvidTtl: int32(c.jwtSVIDTTL), + X509SvidTtl: x509SvidTTL, + JwtSvidTtl: jwtSvidTTL, Hint: c.hint, } @@ -254,7 +266,7 @@ func prettyPrintCreate(env *commoncli.Env, results ...any) error { for _, r := range failed { env.ErrPrintf("Failed to create the following entry (code: %s, msg: %q):\n", - codes.Code(r.Status.Code), + codes.Code(safecast.MustConvert[uint32](r.Status.Code)), r.Status.Message) printEntry(r.Entry, env.ErrPrintf) } diff --git a/cmd/spire-server/cli/entry/delete.go b/cmd/spire-server/cli/entry/delete.go index f3f27cd6bd..bcf3ee34d6 100644 --- a/cmd/spire-server/cli/entry/delete.go +++ b/cmd/spire-server/cli/entry/delete.go @@ -8,6 +8,7 @@ import ( "io" "os" + "github.com/ccoveille/go-safecast" "github.com/mitchellh/cli" entryv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/entry/v1" "github.com/spiffe/spire/cmd/spire-server/util" @@ -135,7 +136,7 @@ func (c *deleteCommand) prettyPrintDelete(env *commoncli.Env, results ...any) er for _, result := range failed { env.ErrPrintf("Failed to delete entry with ID %s (code: %s, msg: %q)\n", result.Id, - codes.Code(result.Status.Code), + codes.Code(safecast.MustConvert[uint32](result.Status.Code)), result.Status.Message) } diff --git a/cmd/spire-server/cli/entry/update.go b/cmd/spire-server/cli/entry/update.go index a6121c6f99..dcd5530c86 100644 --- a/cmd/spire-server/cli/entry/update.go +++ b/cmd/spire-server/cli/entry/update.go @@ -4,7 +4,9 @@ import ( "context" "errors" "flag" + "fmt" + "github.com/ccoveille/go-safecast" "github.com/mitchellh/cli" entryv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/entry/v1" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" @@ -169,6 +171,16 @@ func (c *updateCommand) parseConfig() ([]*types.Entry, error) { return nil, err } + x509SvidTTL, err := safecast.ToInt32(c.x509SvidTTL) + if err != nil { + return nil, fmt.Errorf("X509 SVID TTL: %w", err) + } + + jwtSvidTTL, err := safecast.ToInt32(c.jwtSvidTTL) + if err != nil { + return nil, fmt.Errorf("JWT SVID TTL: %w", err) + } + e := &types.Entry{ Id: c.entryID, ParentId: parentID, @@ -176,8 +188,8 @@ func (c *updateCommand) parseConfig() ([]*types.Entry, error) { Downstream: c.downstream, ExpiresAt: c.entryExpiry, DnsNames: c.dnsNames, - X509SvidTtl: int32(c.x509SvidTTL), - JwtSvidTtl: int32(c.jwtSvidTTL), + X509SvidTtl: x509SvidTTL, + JwtSvidTtl: jwtSvidTTL, Hint: c.hint, } @@ -240,7 +252,7 @@ func prettyPrintUpdate(env *commoncli.Env, results ...any) error { // Print entries that failed to be updated for _, r := range failed { env.ErrPrintf("Failed to update the following entry (code: %s, msg: %q):\n", - codes.Code(r.Status.Code), + codes.Code(safecast.MustConvert[uint32](r.Status.Code)), r.Status.Message) printEntry(r.Entry, env.ErrPrintf) } diff --git a/cmd/spire-server/cli/federation/create.go b/cmd/spire-server/cli/federation/create.go index 65082cb03e..7ce13f88b5 100644 --- a/cmd/spire-server/cli/federation/create.go +++ b/cmd/spire-server/cli/federation/create.go @@ -6,6 +6,7 @@ import ( "flag" "fmt" + "github.com/ccoveille/go-safecast" "github.com/mitchellh/cli" trustdomainv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/trustdomain/v1" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" @@ -101,7 +102,7 @@ func (c *createCommand) prettyPrintCreate(env *commoncli.Env, results ...any) er for _, r := range failed { env.Println() env.ErrPrintf("Failed to create the following federation relationship (code: %s, msg: %q):\n", - codes.Code(r.Status.Code), + codes.Code(safecast.MustConvert[uint32](r.Status.Code)), r.Status.Message) printFederationRelationship(r.FederationRelationship, env.ErrPrintf) } diff --git a/cmd/spire-server/cli/federation/update.go b/cmd/spire-server/cli/federation/update.go index c1c51afa14..ef78cf3df3 100644 --- a/cmd/spire-server/cli/federation/update.go +++ b/cmd/spire-server/cli/federation/update.go @@ -6,6 +6,7 @@ import ( "flag" "fmt" + "github.com/ccoveille/go-safecast" "github.com/mitchellh/cli" trustdomainv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/trustdomain/v1" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" @@ -97,7 +98,7 @@ func (c *updateCommand) prettyPrintUpdate(env *commoncli.Env, results ...any) er for _, r := range failed { env.Println() env.ErrPrintf("Failed to update the following federation relationship (code: %s, msg: %q):\n", - codes.Code(r.Status.Code), + codes.Code(safecast.MustConvert[uint32](r.Status.Code)), r.Status.Message) printFederationRelationship(r.FederationRelationship, env.ErrPrintf) } diff --git a/cmd/spire-server/cli/jwt/mint.go b/cmd/spire-server/cli/jwt/mint.go index a3165069a6..a9e00438fd 100644 --- a/cmd/spire-server/cli/jwt/mint.go +++ b/cmd/spire-server/cli/jwt/mint.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/ccoveille/go-safecast" "github.com/go-jose/go-jose/v4/jwt" "github.com/mitchellh/cli" "github.com/spiffe/go-spiffe/v2/spiffeid" @@ -63,6 +64,10 @@ func (c *mintCommand) Run(ctx context.Context, env *commoncli.Env, serverClient if err != nil { return err } + ttl, err := ttlToSeconds(c.ttl) + if err != nil { + return fmt.Errorf("TTL: %w", err) + } client := serverClient.NewSVIDClient() resp, err := client.MintJWTSVID(ctx, &svidv1.MintJWTSVIDRequest{ @@ -70,7 +75,7 @@ func (c *mintCommand) Run(ctx context.Context, env *commoncli.Env, serverClient TrustDomain: spiffeID.TrustDomain().Name(), Path: spiffeID.Path(), }, - Ttl: ttlToSeconds(c.ttl), + Ttl: ttl, Audience: c.audience, }) if err != nil { @@ -132,8 +137,8 @@ func getJWTSVIDEndOfLife(token string) (time.Time, error) { // ttlToSeconds returns the number of seconds in a duration, rounded up to // the nearest second -func ttlToSeconds(ttl time.Duration) int32 { - return int32((ttl + time.Second - 1) / time.Second) +func ttlToSeconds(ttl time.Duration) (int32, error) { + return safecast.ToInt32((ttl + time.Second - 1) / time.Second) } func prettyPrintMint(env *commoncli.Env, results ...any) error { diff --git a/cmd/spire-server/cli/token/generate.go b/cmd/spire-server/cli/token/generate.go index 7410f05597..1dbc79f134 100644 --- a/cmd/spire-server/cli/token/generate.go +++ b/cmd/spire-server/cli/token/generate.go @@ -3,7 +3,9 @@ package token import ( "context" "flag" + "fmt" + "github.com/ccoveille/go-safecast" "github.com/mitchellh/cli" "github.com/spiffe/go-spiffe/v2/spiffeid" agentv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/agent/v1" @@ -44,11 +46,15 @@ func (g *generateCommand) Run(ctx context.Context, _ *commoncli.Env, serverClien if err != nil { return err } + ttl, err := safecast.ToInt32(g.TTL) + if err != nil { + return fmt.Errorf("TTL: %w", err) + } c := serverClient.NewAgentClient() resp, err := c.CreateJoinToken(ctx, &agentv1.CreateJoinTokenRequest{ AgentId: id, - Ttl: int32(g.TTL), + Ttl: ttl, }) if err != nil { return err diff --git a/cmd/spire-server/cli/x509/mint.go b/cmd/spire-server/cli/x509/mint.go index 801b2d0f0a..b2f67cb37b 100644 --- a/cmd/spire-server/cli/x509/mint.go +++ b/cmd/spire-server/cli/x509/mint.go @@ -15,6 +15,7 @@ import ( "net/url" "time" + "github.com/ccoveille/go-safecast" "github.com/mitchellh/cli" "github.com/spiffe/go-spiffe/v2/spiffeid" bundlev1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/bundle/v1" @@ -80,6 +81,11 @@ func (c *mintCommand) Run(ctx context.Context, env *commoncli.Env, serverClient return err } + ttl, err := ttlToSeconds(c.ttl) + if err != nil { + return fmt.Errorf("TTL: %w", err) + } + key, err := c.generateKey() if err != nil { return fmt.Errorf("unable to generate key: %w", err) @@ -96,7 +102,7 @@ func (c *mintCommand) Run(ctx context.Context, env *commoncli.Env, serverClient client := serverClient.NewSVIDClient() resp, err := client.MintX509SVID(ctx, &svidv1.MintX509SVIDRequest{ Csr: csr, - Ttl: ttlToSeconds(c.ttl), + Ttl: ttl, }) if err != nil { return fmt.Errorf("unable to mint SVID: %w", err) @@ -167,8 +173,8 @@ func (c *mintCommand) Run(ctx context.Context, env *commoncli.Env, serverClient // ttlToSeconds returns the number of seconds in a duration, rounded up to // the nearest second -func ttlToSeconds(ttl time.Duration) int32 { - return int32((ttl + time.Second - 1) / time.Second) +func ttlToSeconds(ttl time.Duration) (int32, error) { + return safecast.ToInt32((ttl + time.Second - 1) / time.Second) } type mintResult struct { diff --git a/go.mod b/go.mod index e7ac478861..873ed413de 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.33.8 github.com/aws/smithy-go v1.22.1 github.com/blang/semver/v4 v4.0.0 + github.com/ccoveille/go-safecast v1.5.0 github.com/cenkalti/backoff/v4 v4.3.0 github.com/docker/docker v27.5.0+incompatible github.com/envoyproxy/go-control-plane/envoy v1.32.3 diff --git a/go.sum b/go.sum index 0d58583144..2915db57fa 100644 --- a/go.sum +++ b/go.sum @@ -643,6 +643,8 @@ github.com/buildkite/roko v1.2.0/go.mod h1:23R9e6nHxgedznkwwfmqZ6+0VJZJZ2Sg/uVcp github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA= github.com/bytecodealliance/wasmtime-go/v3 v3.0.2/go.mod h1:RnUjnIXxEJcL6BgCvNyzCCRzZcxCgsZCi+RNlvYor5Q= github.com/cactus/go-statsd-client/v5 v5.0.0/go.mod h1:COEvJ1E+/E2L4q6QE5CkjWPi4eeDw9maJBMIuMPBZbY= +github.com/ccoveille/go-safecast v1.5.0 h1:cT/3uVQ/i5PTiJvhvkSU81HeKNurtyQtBndXEH3hDg4= +github.com/ccoveille/go-safecast v1.5.0/go.mod h1:QqwNjxQ7DAqY0C721OIO9InMk9zCwcsO7tnRuHytad8= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= diff --git a/pkg/agent/api/debug/v1/service.go b/pkg/agent/api/debug/v1/service.go index 4a6987bd7a..b6a2edb9ac 100644 --- a/pkg/agent/api/debug/v1/service.go +++ b/pkg/agent/api/debug/v1/service.go @@ -3,9 +3,11 @@ package debug import ( "context" "crypto/x509" + "fmt" "sync" "time" + "github.com/ccoveille/go-safecast" "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/bundle/x509bundle" "github.com/spiffe/go-spiffe/v2/spiffeid" @@ -92,15 +94,32 @@ func (s *Service) GetInfo(context.Context, *debugv1.GetInfoRequest) (*debugv1.Ge }) } + uptime, err := safecast.ToInt32(s.uptime().Seconds()) + if err != nil { + return nil, fmt.Errorf("uptime: %w", err) + } + x509SvidsCount, err := safecast.ToInt32(s.m.CountX509SVIDs()) + if err != nil { + return nil, fmt.Errorf("X.509 SVIDs count: %w", err) + } + jwtSvidsCount, err := safecast.ToInt32(s.m.CountJWTSVIDs()) + if err != nil { + return nil, fmt.Errorf("JWT SVIDs count: %w", err) + } + svidstoreX509SvidsCount, err := safecast.ToInt32(s.m.CountSVIDStoreX509SVIDs()) + if err != nil { + return nil, fmt.Errorf("SVIDStore X.509 SVIDs count: %w", err) + } + // Reset clock and set current response s.getInfoResp.ts = s.clock.Now() s.getInfoResp.resp = &debugv1.GetInfoResponse{ SvidChain: svidChain, - Uptime: int32(s.uptime().Seconds()), - SvidsCount: int32(s.m.CountX509SVIDs()), - CachedX509SvidsCount: int32(s.m.CountX509SVIDs()), - CachedJwtSvidsCount: int32(s.m.CountJWTSVIDs()), - CachedSvidstoreX509SvidsCount: int32(s.m.CountSVIDStoreX509SVIDs()), + Uptime: uptime, + SvidsCount: x509SvidsCount, + CachedX509SvidsCount: x509SvidsCount, + CachedJwtSvidsCount: jwtSvidsCount, + CachedSvidstoreX509SvidsCount: svidstoreX509SvidsCount, LastSyncSuccess: s.m.GetLastSync().UTC().Unix(), } } diff --git a/pkg/agent/plugin/keymanager/base/keymanagerbase.go b/pkg/agent/plugin/keymanager/base/keymanagerbase.go index cd54e170f1..7e10393152 100644 --- a/pkg/agent/plugin/keymanager/base/keymanagerbase.go +++ b/pkg/agent/plugin/keymanager/base/keymanagerbase.go @@ -14,6 +14,7 @@ import ( "sort" "sync" + "github.com/ccoveille/go-safecast" keymanagerv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/agent/keymanager/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -170,7 +171,7 @@ func (m *Base) signData(req *keymanagerv1.SignDataRequest) (*keymanagerv1.SignDa if opts.HashAlgorithm == keymanagerv1.HashAlgorithm_UNSPECIFIED_HASH_ALGORITHM { return nil, status.Error(codes.InvalidArgument, "hash algorithm is required") } - signerOpts = crypto.Hash(opts.HashAlgorithm) + signerOpts = crypto.Hash(safecast.MustConvert[uint](opts.HashAlgorithm)) case *keymanagerv1.SignDataRequest_PssOptions: if opts.PssOptions == nil { return nil, status.Error(codes.InvalidArgument, "PSS options are nil") @@ -180,7 +181,7 @@ func (m *Base) signData(req *keymanagerv1.SignDataRequest) (*keymanagerv1.SignDa } signerOpts = &rsa.PSSOptions{ SaltLength: int(opts.PssOptions.SaltLength), - Hash: crypto.Hash(opts.PssOptions.HashAlgorithm), + Hash: crypto.Hash(safecast.MustConvert[uint](opts.PssOptions.HashAlgorithm)), } default: return nil, status.Errorf(codes.InvalidArgument, "unsupported signer opts type %T", opts) diff --git a/pkg/agent/plugin/keymanager/v1.go b/pkg/agent/plugin/keymanager/v1.go index 00c9b08ee8..75e1611eae 100644 --- a/pkg/agent/plugin/keymanager/v1.go +++ b/pkg/agent/plugin/keymanager/v1.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "io" + "github.com/ccoveille/go-safecast" keymanagerv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/agent/keymanager/v1" "github.com/spiffe/spire/pkg/common/plugin" "google.golang.org/grpc/codes" @@ -118,7 +119,7 @@ func (v1 *V1) convertKeyType(t KeyType) (keymanagerv1.KeyType, error) { func (v1 *V1) convertHashAlgorithm(h crypto.Hash) keymanagerv1.HashAlgorithm { // Hash algorithm constants are aligned. - return keymanagerv1.HashAlgorithm(h) + return keymanagerv1.HashAlgorithm(safecast.MustConvert[int32](h)) } type v1Key struct { @@ -155,7 +156,7 @@ func (s *v1Key) signContext(ctx context.Context, digest []byte, opts crypto.Sign case *rsa.PSSOptions: req.SignerOpts = &keymanagerv1.SignDataRequest_PssOptions{ PssOptions: &keymanagerv1.SignDataRequest_PSSOptions{ - SaltLength: int32(opts.SaltLength), + SaltLength: safecast.MustConvert[int32](opts.SaltLength), HashAlgorithm: s.v1.convertHashAlgorithm(opts.Hash), }, } diff --git a/pkg/agent/plugin/workloadattestor/docker/docker_posix.go b/pkg/agent/plugin/workloadattestor/docker/docker_posix.go index 9c2bdacaeb..335f087e77 100644 --- a/pkg/agent/plugin/workloadattestor/docker/docker_posix.go +++ b/pkg/agent/plugin/workloadattestor/docker/docker_posix.go @@ -92,7 +92,7 @@ func (h *containerHelper) getContainerID(pID int32, log hclog.Logger) (string, e } extractor := containerinfo.Extractor{RootDir: h.rootDir, VerboseLogging: h.verboseContainerLocatorLogs} - return extractor.GetContainerID(int(pID), log) + return extractor.GetContainerID(pID, log) } func getDockerHost(c *dockerPluginConfig) string { diff --git a/pkg/agent/plugin/workloadattestor/k8s/k8s_posix.go b/pkg/agent/plugin/workloadattestor/k8s/k8s_posix.go index c29fdf8304..35510de46c 100644 --- a/pkg/agent/plugin/workloadattestor/k8s/k8s_posix.go +++ b/pkg/agent/plugin/workloadattestor/k8s/k8s_posix.go @@ -65,7 +65,7 @@ func (h *containerHelper) GetPodUIDAndContainerID(pID int32, log hclog.Logger) ( } extractor := containerinfo.Extractor{RootDir: h.rootDir, VerboseLogging: h.verboseContainerLocatorLogs} - return extractor.GetPodUIDAndContainerID(int(pID), log) + return extractor.GetPodUIDAndContainerID(pID, log) } func getPodUIDAndContainerIDFromCGroups(cgroups []cgroups.Cgroup) (types.UID, string, error) { diff --git a/pkg/agent/plugin/workloadattestor/systemd/systemd_posix.go b/pkg/agent/plugin/workloadattestor/systemd/systemd_posix.go index 44c276bd09..851db2f7f1 100644 --- a/pkg/agent/plugin/workloadattestor/systemd/systemd_posix.go +++ b/pkg/agent/plugin/workloadattestor/systemd/systemd_posix.go @@ -7,6 +7,7 @@ import ( "fmt" "sync" + "github.com/ccoveille/go-safecast" "github.com/godbus/dbus/v5" "github.com/hashicorp/go-hclog" workloadattestorv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/agent/workloadattestor/v1" @@ -55,7 +56,11 @@ func (p *Plugin) SetLogger(log hclog.Logger) { } func (p *Plugin) Attest(ctx context.Context, req *workloadattestorv1.AttestRequest) (*workloadattestorv1.AttestResponse, error) { - uInfo, err := p.getUnitInfo(ctx, p, uint(req.Pid)) + pid, err := safecast.ToUint(req.Pid) + if err != nil { + return nil, fmt.Errorf("PID: %w", err) + } + uInfo, err := p.getUnitInfo(ctx, p, pid) if err != nil { return nil, err } diff --git a/pkg/agent/plugin/workloadattestor/v1.go b/pkg/agent/plugin/workloadattestor/v1.go index ebcebb9847..22adeb6e53 100644 --- a/pkg/agent/plugin/workloadattestor/v1.go +++ b/pkg/agent/plugin/workloadattestor/v1.go @@ -2,7 +2,9 @@ package workloadattestor import ( "context" + "fmt" + "github.com/ccoveille/go-safecast" workloadattestorv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/agent/workloadattestor/v1" "github.com/spiffe/spire/pkg/common/plugin" "github.com/spiffe/spire/proto/spire/common" @@ -14,8 +16,12 @@ type V1 struct { } func (v1 *V1) Attest(ctx context.Context, pid int) ([]*common.Selector, error) { + pidInt32, err := safecast.ToInt32(pid) + if err != nil { + return nil, v1.WrapErr(fmt.Errorf("PID: %w", err)) + } resp, err := v1.WorkloadAttestorPluginClient.Attest(ctx, &workloadattestorv1.AttestRequest{ - Pid: int32(pid), + Pid: pidInt32, }) if err != nil { return nil, v1.WrapErr(err) diff --git a/pkg/common/containerinfo/extract.go b/pkg/common/containerinfo/extract.go index a60a3930cb..ebc7118a8d 100644 --- a/pkg/common/containerinfo/extract.go +++ b/pkg/common/containerinfo/extract.go @@ -40,16 +40,16 @@ type Extractor struct { VerboseLogging bool } -func (e *Extractor) GetContainerID(pid int, log hclog.Logger) (string, error) { +func (e *Extractor) GetContainerID(pid int32, log hclog.Logger) (string, error) { _, containerID, err := e.extractInfo(pid, log, false) return containerID, err } -func (e *Extractor) GetPodUIDAndContainerID(pid int, log hclog.Logger) (types.UID, string, error) { +func (e *Extractor) GetPodUIDAndContainerID(pid int32, log hclog.Logger) (types.UID, string, error) { return e.extractInfo(pid, log, true) } -func (e *Extractor) extractInfo(pid int, log hclog.Logger, extractPodUID bool) (types.UID, string, error) { +func (e *Extractor) extractInfo(pid int32, log hclog.Logger, extractPodUID bool) (types.UID, string, error) { // Try to get the information from /proc/pid/mountinfo first. Otherwise, // fall back to /proc/pid/cgroup. If it isn't in mountinfo, then the // workload being attested likely originates in the same Pod as the agent. @@ -74,7 +74,7 @@ func (e *Extractor) extractInfo(pid int, log hclog.Logger, extractPodUID bool) ( return podUID, containerID, nil } -func (e *Extractor) extractPodUIDAndContainerIDFromMountInfo(pid int, log hclog.Logger, extractPodUID bool) (types.UID, string, error) { +func (e *Extractor) extractPodUIDAndContainerIDFromMountInfo(pid int32, log hclog.Logger, extractPodUID bool) (types.UID, string, error) { mountInfoPath := filepath.Join(e.RootDir, "/proc", fmt.Sprint(pid), "mountinfo") mountInfos, err := mount.ParseMountInfo(mountInfoPath) @@ -122,8 +122,8 @@ func (e *Extractor) extractPodUIDAndContainerIDFromMountInfo(pid int, log hclog. return ex.PodUID(), ex.ContainerID(), nil } -func (e *Extractor) extractPodUIDAndContainerIDFromCGroups(pid int, log hclog.Logger, extractPodUID bool) (types.UID, string, error) { - cgroups, err := cgroups.GetCgroups(int32(pid), dirFS(e.RootDir)) +func (e *Extractor) extractPodUIDAndContainerIDFromCGroups(pid int32, log hclog.Logger, extractPodUID bool) (types.UID, string, error) { + cgroups, err := cgroups.GetCgroups(pid, dirFS(e.RootDir)) if err != nil { if errors.Is(err, fs.ErrNotExist) { return "", "", nil diff --git a/pkg/common/selector/set.go b/pkg/common/selector/set.go index 656692ceb1..c3ce8c0de0 100644 --- a/pkg/common/selector/set.go +++ b/pkg/common/selector/set.go @@ -9,7 +9,6 @@ import ( type Set interface { Raw() []*common.Selector Array() []*Selector - Power() <-chan Set Equal(otherSet Set) bool Includes(selector *Selector) bool IncludesSet(s2 Set) bool @@ -64,10 +63,6 @@ func (s *set) Array() []*Selector { return c } -func (s *set) Power() <-chan Set { - return PowerSet(s) -} - func (s *set) Equal(otherSet Set) bool { return EqualSet(s, otherSet.(*set)) } diff --git a/pkg/common/selector/set_utils.go b/pkg/common/selector/set_utils.go index 4a64f944d2..f43fd65611 100644 --- a/pkg/common/selector/set_utils.go +++ b/pkg/common/selector/set_utils.go @@ -1,27 +1,5 @@ package selector -import ( - "math" - "strconv" - "strings" -) - -// PowerSet implements a range-able combination generator. It takes a set, and -// returns a channel over which all possible combinations of selectors are eventually -// returned. It is meant to aid in the discovery of applicable cache entries, given the -// superset of selectors discovered during attestation. -func PowerSet(selectors *set) <-chan Set { - c := make(chan Set) - - go func() { - powerSet(selectors, c) - - close(c) - }() - - return c -} - // EqualSet determines whether two sets of selectors are equal or not func EqualSet(a, b *set) bool { if a.Size() != b.Size() { @@ -55,33 +33,3 @@ func IncludesSet(s1, s2 *set) bool { } return true } - -// powerSet, given a set of selectors, returns every possible combination -// of selector subsets. -// -// https://en.wikipedia.org/wiki/Power_set -func powerSet(s *set, c chan Set) { - sarr := s.Array() - powSetSize := math.Pow(2, float64(len(*s))) - - // Skip the empty set by starting the counter at 1 - for i := 1; i < int(powSetSize); i++ { - set := &set{} - - // Form binary representation of the counter - binaryString := strconv.FormatUint(uint64(i), 2) - binary := strings.Split(binaryString, "") - - // Walk through the binary, and append - // "enabled" elements to the working set - for position := range binary { - // Read the binary right to left - negPosition := len(binary) - position - 1 - if binary[negPosition] == "1" { - set.Add(sarr[position]) - } - } - - c <- set - } -} diff --git a/pkg/common/selector/set_utils_test.go b/pkg/common/selector/set_utils_test.go index 696facaac9..410a1a1cd8 100644 --- a/pkg/common/selector/set_utils_test.go +++ b/pkg/common/selector/set_utils_test.go @@ -9,9 +9,6 @@ import ( var ( selector1 = &Selector{Type: "foo", Value: "bar"} selector2 = &Selector{Type: "bar", Value: "bat"} - selector3 = &Selector{Type: "bat", Value: "baz"} - selector4 = &Selector{Type: "baz", Value: "quz"} - selector5 = &Selector{Type: "quz", Value: "foo"} ) func TestEqualSet(t *testing.T) { @@ -24,67 +21,3 @@ func TestEqualSet(t *testing.T) { set2.Remove(selector1) a.True(!set1.Equal(set2)) } - -func TestPowerSet(t *testing.T) { - a := assert.New(t) - - selectorSet := NewSet( - selector1, - selector2, - selector3, - selector4, - selector5, - ) - - expectedResults := []Set{ - NewSet(selector1), - NewSet(selector2), - NewSet(selector1, selector2), - NewSet(selector3), - NewSet(selector1, selector3), - NewSet(selector2, selector3), - NewSet(selector1, selector2, selector3), - NewSet(selector4), - NewSet(selector1, selector4), - NewSet(selector2, selector4), - NewSet(selector1, selector2, selector4), - NewSet(selector3, selector4), - NewSet(selector1, selector3, selector4), - NewSet(selector2, selector3, selector4), - NewSet(selector1, selector2, selector3, selector4), - NewSet(selector5), - NewSet(selector1, selector5), - NewSet(selector2, selector5), - NewSet(selector1, selector2, selector5), - NewSet(selector3, selector5), - NewSet(selector1, selector3, selector5), - NewSet(selector2, selector3, selector5), - NewSet(selector1, selector2, selector3, selector5), - NewSet(selector4, selector5), - NewSet(selector1, selector4, selector5), - NewSet(selector2, selector4, selector5), - NewSet(selector1, selector2, selector4, selector5), - NewSet(selector3, selector4, selector5), - NewSet(selector1, selector3, selector4, selector5), - NewSet(selector2, selector3, selector4, selector5), - NewSet(selector1, selector2, selector3, selector4, selector5), - } - - var results []Set - for result := range PowerSet(selectorSet.(*set)) { - results = append(results, result) - } - - if a.Equal(len(expectedResults), len(results)) { - for _, resultSet := range results { - var isIncluded bool - for _, expectedSet := range expectedResults { - if expectedSet.Equal(resultSet) { - isIncluded = true - break - } - } - a.True(isIncluded) - } - } -} diff --git a/pkg/server/api/audit/audit.go b/pkg/server/api/audit/audit.go index 5ef806434d..8d9fe47008 100644 --- a/pkg/server/api/audit/audit.go +++ b/pkg/server/api/audit/audit.go @@ -1,6 +1,7 @@ package audit import ( + "github.com/ccoveille/go-safecast" "github.com/sirupsen/logrus" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/pkg/common/telemetry" @@ -61,7 +62,7 @@ func (l *logger) AuditWithTypesStatus(fields logrus.Fields, s *types.Status) { } func fieldsFromStatus(s *types.Status) logrus.Fields { - err := status.Error(codes.Code(s.Code), s.Message) + err := status.Error(codes.Code(safecast.MustConvert[uint32](s.Code)), s.Message) return fieldsFromError(err) } diff --git a/pkg/server/api/status.go b/pkg/server/api/status.go index 6f5ff94ff3..a24a2d18c9 100644 --- a/pkg/server/api/status.go +++ b/pkg/server/api/status.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/ccoveille/go-safecast" "github.com/sirupsen/logrus" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "google.golang.org/grpc/codes" @@ -13,7 +14,7 @@ import ( // CreateStatus creates a proto Status func CreateStatus(code codes.Code, msg string) *types.Status { return &types.Status{ - Code: int32(code), + Code: safecast.MustConvert[int32](code), Message: msg, } } @@ -21,7 +22,7 @@ func CreateStatus(code codes.Code, msg string) *types.Status { // CreateStatus creates a proto Status func CreateStatusf(code codes.Code, format string, a ...any) *types.Status { return &types.Status{ - Code: int32(code), + Code: safecast.MustConvert[int32](code), Message: fmt.Sprintf(format, a...), } } diff --git a/pkg/server/datastore/sqlstore/migration.go b/pkg/server/datastore/sqlstore/migration.go index 0d8eece2c7..79a4f039f3 100644 --- a/pkg/server/datastore/sqlstore/migration.go +++ b/pkg/server/datastore/sqlstore/migration.go @@ -3,7 +3,6 @@ package sqlstore import ( "errors" "fmt" - "math" "strconv" "github.com/blang/semver/v4" @@ -389,12 +388,12 @@ func getDBCodeVersion(migration Migration) (dbCodeVersion semver.Version, err er } func isCompatibleCodeVersion(thisCodeVersion, dbCodeVersion semver.Version) bool { - // If major version is the same and minor version is +/- 1, versions are - // compatible - if dbCodeVersion.Major != thisCodeVersion.Major || (math.Abs(float64(int64(dbCodeVersion.Minor)-int64(thisCodeVersion.Minor))) > 1) { - return false + // If major version is the same and minor version is +/- 1, versions are compatible + minMinor, maxMinor := min(dbCodeVersion.Minor, thisCodeVersion.Minor), max(dbCodeVersion.Minor, thisCodeVersion.Minor) + if dbCodeVersion.Major == thisCodeVersion.Major && (minMinor == maxMinor || minMinor+1 == maxMinor) { + return true } - return true + return false } func initDB(db *gorm.DB, dbType string, log logrus.FieldLogger) (err error) { diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 45ed51340f..aaccd5ddca 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -14,6 +14,7 @@ import ( "time" "unicode" + "github.com/ccoveille/go-safecast" "github.com/gofrs/uuid/v5" "github.com/hashicorp/hcl" "github.com/hashicorp/hcl/hcl/ast" @@ -1303,7 +1304,7 @@ func countBundles(tx *gorm.DB) (int32, error) { return 0, newWrappedSQLError(err) } - return int32(count), nil + return safecast.ToInt32(count) } // listBundles can be used to fetch all existing bundles. @@ -1589,7 +1590,7 @@ func countAttestedNodes(tx *gorm.DB) (int32, error) { return 0, newWrappedSQLError(err) } - return int32(count), nil + return safecast.ToInt32(count) } func countAttestedNodesHasFilters(req *datastore.CountAttestedNodesRequest) bool { @@ -1688,7 +1689,7 @@ func countAttestedNodesWithFilters(ctx context.Context, db *sqlDB, _ logrus.Fiel } } - val += int32(len(resp.Nodes)) + val += safecast.MustConvert[int32](len(resp.Nodes)) listReq.Pagination = resp.Pagination } @@ -3327,7 +3328,7 @@ func countRegistrationEntries(ctx context.Context, db *sqlDB, _ logrus.FieldLogg } } - val += int32(len(resp.Entries)) + val += safecast.MustConvert[int32](len(resp.Entries)) listReq.Pagination = resp.Pagination } @@ -3852,10 +3853,16 @@ func fillEntryFromRow(entry *common.RegistrationEntry, r *entryRow) error { entry.FederatesWith = append(entry.FederatesWith, r.TrustDomain.String) } if r.RegTTL.Valid { - entry.X509SvidTtl = int32(r.RegTTL.Int64) + var err error + if entry.X509SvidTtl, err = safecast.ToInt32(r.RegTTL.Int64); err != nil { + return newSQLError("RegTTL: %s", err) + } } if r.RegJwtSvidTTL.Valid { - entry.JwtSvidTtl = int32(r.RegJwtSvidTTL.Int64) + var err error + if entry.JwtSvidTtl, err = safecast.ToInt32(r.RegJwtSvidTTL.Int64); err != nil { + return newSQLError("RegJwtSvidTTL: %s", err) + } } if r.Hint.Valid { entry.Hint = r.Hint.String diff --git a/pkg/server/endpoints/authorized_entryfetcher_attested_nodes.go b/pkg/server/endpoints/authorized_entryfetcher_attested_nodes.go index a514f67d89..39e0bbd011 100644 --- a/pkg/server/endpoints/authorized_entryfetcher_attested_nodes.go +++ b/pkg/server/endpoints/authorized_entryfetcher_attested_nodes.go @@ -239,8 +239,8 @@ func (a *attestedNodes) updateCachedNodes(ctx context.Context) error { } func (a *attestedNodes) emitMetrics() { - if a.skippedNodeEvents != int(a.eventTracker.EventCount()) { - a.skippedNodeEvents = int(a.eventTracker.EventCount()) + if a.skippedNodeEvents != a.eventTracker.EventCount() { + a.skippedNodeEvents = a.eventTracker.EventCount() server_telemetry.SetSkippedNodeEventIDsCacheCountGauge(a.metrics, a.skippedNodeEvents) } diff --git a/pkg/server/endpoints/authorized_entryfetcher_registration_entries.go b/pkg/server/endpoints/authorized_entryfetcher_registration_entries.go index cc32536cef..74fcd443ee 100644 --- a/pkg/server/endpoints/authorized_entryfetcher_registration_entries.go +++ b/pkg/server/endpoints/authorized_entryfetcher_registration_entries.go @@ -251,8 +251,8 @@ func (a *registrationEntries) updateCachedEntries(ctx context.Context) error { } func (a *registrationEntries) emitMetrics() { - if a.skippedEntryEvents != int(a.eventTracker.EventCount()) { - a.skippedEntryEvents = int(a.eventTracker.EventCount()) + if a.skippedEntryEvents != a.eventTracker.EventCount() { + a.skippedEntryEvents = a.eventTracker.EventCount() server_telemetry.SetSkippedEntryEventIDsCacheCountGauge(a.metrics, a.skippedEntryEvents) } diff --git a/pkg/server/endpoints/eventTracker.go b/pkg/server/endpoints/eventTracker.go index dcaf493b8c..a19cfda6bb 100644 --- a/pkg/server/endpoints/eventTracker.go +++ b/pkg/server/endpoints/eventTracker.go @@ -3,6 +3,8 @@ package endpoints import ( "sync" "time" + + "github.com/ccoveille/go-safecast" ) type eventTracker struct { @@ -20,7 +22,7 @@ func PollPeriods(pollTime time.Duration, trackTime time.Duration) uint { if trackTime < time.Second { trackTime = time.Second } - return uint(1 + (trackTime-1)/pollTime) + return safecast.MustConvert[uint](1 + (trackTime-1)/pollTime) } func NewEventTracker(pollPeriods uint) *eventTracker { @@ -74,6 +76,6 @@ func (et *eventTracker) FreeEvents(events []uint) { et.pool.Put(&events) } -func (et *eventTracker) EventCount() uint { - return uint(len(et.events)) +func (et *eventTracker) EventCount() int { + return len(et.events) } diff --git a/pkg/server/plugin/keymanager/base/keymanagerbase.go b/pkg/server/plugin/keymanager/base/keymanagerbase.go index 00a4d7cabe..53c5c754ee 100644 --- a/pkg/server/plugin/keymanager/base/keymanagerbase.go +++ b/pkg/server/plugin/keymanager/base/keymanagerbase.go @@ -14,6 +14,7 @@ import ( "sort" "sync" + "github.com/ccoveille/go-safecast" keymanagerv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/keymanager/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -170,7 +171,7 @@ func (m *Base) signData(req *keymanagerv1.SignDataRequest) (*keymanagerv1.SignDa if opts.HashAlgorithm == keymanagerv1.HashAlgorithm_UNSPECIFIED_HASH_ALGORITHM { return nil, status.Error(codes.InvalidArgument, "hash algorithm is required") } - signerOpts = crypto.Hash(opts.HashAlgorithm) + signerOpts = crypto.Hash(safecast.MustConvert[uint](opts.HashAlgorithm)) case *keymanagerv1.SignDataRequest_PssOptions: if opts.PssOptions == nil { return nil, status.Error(codes.InvalidArgument, "PSS options are nil") @@ -180,7 +181,7 @@ func (m *Base) signData(req *keymanagerv1.SignDataRequest) (*keymanagerv1.SignDa } signerOpts = &rsa.PSSOptions{ SaltLength: int(opts.PssOptions.SaltLength), - Hash: crypto.Hash(opts.PssOptions.HashAlgorithm), + Hash: crypto.Hash(safecast.MustConvert[uint](opts.PssOptions.HashAlgorithm)), } default: return nil, status.Errorf(codes.InvalidArgument, "unsupported signer opts type %T", opts) diff --git a/pkg/server/plugin/keymanager/v1.go b/pkg/server/plugin/keymanager/v1.go index 2cb522691c..8f8680f56c 100644 --- a/pkg/server/plugin/keymanager/v1.go +++ b/pkg/server/plugin/keymanager/v1.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "io" + "github.com/ccoveille/go-safecast" keymanagerv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/keymanager/v1" "github.com/spiffe/spire/pkg/common/plugin" "google.golang.org/grpc/codes" @@ -118,7 +119,7 @@ func (v1 *V1) convertKeyType(t KeyType) (keymanagerv1.KeyType, error) { func (v1 *V1) convertHashAlgorithm(h crypto.Hash) keymanagerv1.HashAlgorithm { // Hash algorithm constants are aligned. - return keymanagerv1.HashAlgorithm(h) + return keymanagerv1.HashAlgorithm(safecast.MustConvert[int32](h)) } type v1Key struct { @@ -155,7 +156,7 @@ func (s *v1Key) signContext(ctx context.Context, digest []byte, opts crypto.Sign case *rsa.PSSOptions: req.SignerOpts = &keymanagerv1.SignDataRequest_PssOptions{ PssOptions: &keymanagerv1.SignDataRequest_PSSOptions{ - SaltLength: int32(opts.SaltLength), + SaltLength: safecast.MustConvert[int32](opts.SaltLength), HashAlgorithm: s.v1.convertHashAlgorithm(opts.Hash), }, } diff --git a/pkg/server/plugin/upstreamauthority/v1.go b/pkg/server/plugin/upstreamauthority/v1.go index 63f1330806..6b26729c8d 100644 --- a/pkg/server/plugin/upstreamauthority/v1.go +++ b/pkg/server/plugin/upstreamauthority/v1.go @@ -7,6 +7,7 @@ import ( "io" "time" + "github.com/ccoveille/go-safecast" upstreamauthorityv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/upstreamauthority/v1" "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/types" "github.com/spiffe/spire/pkg/common/coretypes/jwtkey" @@ -35,7 +36,7 @@ func (v1 *V1) MintX509CA(ctx context.Context, csr []byte, preferredTTL time.Dura stream, err := v1.UpstreamAuthorityPluginClient.MintX509CAAndSubscribe(ctx, &upstreamauthorityv1.MintX509CARequest{ Csr: csr, - PreferredTtl: int32(preferredTTL / time.Second), + PreferredTtl: safecast.MustConvert[int32](preferredTTL / time.Second), }) if err != nil { return nil, nil, nil, v1.WrapErr(err) diff --git a/test/clock/clock.go b/test/clock/clock.go index d1a53892d4..6c89d7162c 100644 --- a/test/clock/clock.go +++ b/test/clock/clock.go @@ -85,13 +85,13 @@ func (m *Mock) WaitForTicker(timeout time.Duration, format string, args ...any) m.WaitForTickerMulti(timeout, 1, format, args...) } -func (m *Mock) WaitForTickerMulti(timeout time.Duration, count int, format string, args ...any) { +func (m *Mock) WaitForTickerMulti(timeout time.Duration, count int32, format string, args ...any) { deadlineChan := time.After(timeout) for { select { case <-m.tickerC: - if m.tickerCount.Load() >= int32(count) { - m.tickerCount.Add(-1 * int32(count)) + if m.tickerCount.Load() >= count { + m.tickerCount.Add(-1 * count) return } case <-deadlineChan: