Skip to content

Commit

Permalink
Refactor tctl InitFunc to return an authclient.ClientI
Browse files Browse the repository at this point in the history
Returning the interface from the InitFunc allows tests to better
mock the auth client, which permits tests to be run without an
actual teleport process being launched.

The ClientI interface was also extended with additional methods
that tctl was already relying on.
  • Loading branch information
rosstimothy committed Jan 16, 2025
1 parent 60ff7d4 commit 1d542fc
Show file tree
Hide file tree
Showing 44 changed files with 272 additions and 240 deletions.
32 changes: 32 additions & 0 deletions lib/auth/authclient/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package authclient

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
Expand All @@ -41,8 +42,10 @@ import (
"github.com/gravitational/teleport/api/client/usertask"
apidefaults "github.com/gravitational/teleport/api/defaults"
accessgraphsecretsv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1"
autoupdatev1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1"
clusterconfigpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/clusterconfig/v1"
dbobjectimportrulev1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/dbobjectimportrule/v1"
decisionv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/decision/v1alpha1"
devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1"
identitycenterv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/identitycenter/v1"
integrationv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/integration/v1"
Expand All @@ -56,6 +59,7 @@ import (
trustpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/trust/v1"
userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1"
"github.com/gravitational/teleport/api/gen/proto/go/teleport/vnet/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
userpreferencesv1 "github.com/gravitational/teleport/api/gen/proto/go/userpreferences/v1"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -1826,6 +1830,9 @@ type ClientI interface {
// when calling this method, but all RPCs will return "not implemented" errors
// (as per the default gRPC behavior).
WorkloadIdentityServiceClient() machineidv1pb.WorkloadIdentityServiceClient
SPIFFEFederationServiceClient() machineidv1pb.SPIFFEFederationServiceClient
WorkloadIdentityResourceServiceClient() workloadidentityv1pb.WorkloadIdentityResourceServiceClient
WorkloadIdentityIssuanceClient() workloadidentityv1pb.WorkloadIdentityIssuanceServiceClient

// NotificationServiceClient returns a notification service client.
// Clients connecting to older Teleport versions, still get a client
Expand Down Expand Up @@ -1903,4 +1910,29 @@ type ClientI interface {

// GitServerReadOnlyClient returns the read-only client for Git servers.
GitServerReadOnlyClient() gitserver.ReadOnlyClient

DecisionClient() decisionv1.DecisionServiceClient

SetMFAPromptConstructor(pc mfa.PromptConstructor)

CreateAutoUpdateConfig(ctx context.Context, config *autoupdatev1pb.AutoUpdateConfig) (*autoupdatev1pb.AutoUpdateConfig, error)
UpdateAutoUpdateConfig(ctx context.Context, config *autoupdatev1pb.AutoUpdateConfig) (*autoupdatev1pb.AutoUpdateConfig, error)
UpsertAutoUpdateConfig(ctx context.Context, config *autoupdatev1pb.AutoUpdateConfig) (*autoupdatev1pb.AutoUpdateConfig, error)
DeleteAutoUpdateConfig(ctx context.Context) error

CreateAutoUpdateVersion(ctx context.Context, config *autoupdatev1pb.AutoUpdateVersion) (*autoupdatev1pb.AutoUpdateVersion, error)
UpdateAutoUpdateVersion(ctx context.Context, config *autoupdatev1pb.AutoUpdateVersion) (*autoupdatev1pb.AutoUpdateVersion, error)
UpsertAutoUpdateVersion(ctx context.Context, config *autoupdatev1pb.AutoUpdateVersion) (*autoupdatev1pb.AutoUpdateVersion, error)
DeleteAutoUpdateVersion(ctx context.Context) error

CreateAutoUpdateAgentRollout(ctx context.Context, config *autoupdatev1pb.AutoUpdateAgentRollout) (*autoupdatev1pb.AutoUpdateAgentRollout, error)
UpdateAutoUpdateAgentRollout(ctx context.Context, config *autoupdatev1pb.AutoUpdateAgentRollout) (*autoupdatev1pb.AutoUpdateAgentRollout, error)
UpsertAutoUpdateAgentRollout(ctx context.Context, config *autoupdatev1pb.AutoUpdateAgentRollout) (*autoupdatev1pb.AutoUpdateAgentRollout, error)
DeleteAutoUpdateAgentRollout(cxt context.Context) error

GetDesktopBootstrapScript(ctx context.Context) (string, error)

CrownJewelsClient() services.CrownJewels
UserTasksClient() services.UserTasks
Config() *tls.Config
}
22 changes: 11 additions & 11 deletions tool/tctl/common/access_request_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (c *AccessRequestCommand) Initialize(app *kingpin.Application, _ *tctlcfg.G

// TryRun takes the CLI command as an argument (like "access-request list") and executes it.
func (c *AccessRequestCommand) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
var commandFunc func(ctx context.Context, client *authclient.Client) error
var commandFunc func(ctx context.Context, client authclient.ClientI) error
switch cmd {
case c.requestList.FullCommand():
commandFunc = c.List
Expand Down Expand Up @@ -160,7 +160,7 @@ func (c *AccessRequestCommand) TryRun(ctx context.Context, cmd string, clientFun
return true, trace.Wrap(err)
}

func (c *AccessRequestCommand) List(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) List(ctx context.Context, client authclient.ClientI) error {
var index proto.AccessRequestSort
switch c.sortIndex {
case "created":
Expand Down Expand Up @@ -203,7 +203,7 @@ func (c *AccessRequestCommand) List(ctx context.Context, client *authclient.Clie
return nil
}

func (c *AccessRequestCommand) Get(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Get(ctx context.Context, client authclient.ClientI) error {
reqs := []types.AccessRequest{}
for _, reqID := range strings.Split(c.reqIDs, ",") {
req, err := client.GetAccessRequests(ctx, types.AccessRequestFilter{
Expand Down Expand Up @@ -258,7 +258,7 @@ func (c *AccessRequestCommand) splitRoles() []string {
return roles
}

func (c *AccessRequestCommand) Approve(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Approve(ctx context.Context, client authclient.ClientI) error {
if c.delegator != "" {
ctx = authz.WithDelegator(ctx, c.delegator)
}
Expand Down Expand Up @@ -289,7 +289,7 @@ func (c *AccessRequestCommand) Approve(ctx context.Context, client *authclient.C
return nil
}

func (c *AccessRequestCommand) Deny(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Deny(ctx context.Context, client authclient.ClientI) error {
if c.delegator != "" {
ctx = authz.WithDelegator(ctx, c.delegator)
}
Expand All @@ -310,7 +310,7 @@ func (c *AccessRequestCommand) Deny(ctx context.Context, client *authclient.Clie
return nil
}

func (c *AccessRequestCommand) Create(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Create(ctx context.Context, client authclient.ClientI) error {
if len(c.roles) == 0 && len(c.requestedResourceIDs) == 0 {
c.roles = "*"
}
Expand All @@ -326,10 +326,10 @@ func (c *AccessRequestCommand) Create(ctx context.Context, client *authclient.Cl

if c.dryRun {
users := &struct {
*authclient.Client
authclient.ClientI
services.UserLoginStatesGetter
}{
Client: client,
ClientI: client,
UserLoginStatesGetter: client.UserLoginStateClient(),
}
err = services.ValidateAccessRequestForUser(ctx, clockwork.NewRealClock(), users, req, tlsca.Identity{}, services.ExpandVars(true))
Expand All @@ -346,7 +346,7 @@ func (c *AccessRequestCommand) Create(ctx context.Context, client *authclient.Cl
return nil
}

func (c *AccessRequestCommand) Delete(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Delete(ctx context.Context, client authclient.ClientI) error {
var approvedTokens []string
for _, reqID := range strings.Split(c.reqIDs, ",") {
// Fetch the requests first to see if they were approved to provide the
Expand Down Expand Up @@ -386,7 +386,7 @@ func (c *AccessRequestCommand) Delete(ctx context.Context, client *authclient.Cl
return nil
}

func (c *AccessRequestCommand) Caps(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Caps(ctx context.Context, client authclient.ClientI) error {
caps, err := client.GetAccessCapabilities(ctx, types.AccessCapabilitiesRequest{
User: c.user,
RequestableRoles: true,
Expand Down Expand Up @@ -422,7 +422,7 @@ func (c *AccessRequestCommand) Caps(ctx context.Context, client *authclient.Clie
}
}

func (c *AccessRequestCommand) Review(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Review(ctx context.Context, client authclient.ClientI) error {
if c.approve == c.deny {
return trace.BadParameter("must supply exactly one of '--approve' or '--deny'")
}
Expand Down
22 changes: 11 additions & 11 deletions tool/tctl/common/accessmonitoring/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (c *Command) initAuditReportsCommands(auditCmd *kingpin.CmdClause, cfg *ser
})
}

type runFunc func(context.Context, *authclient.Client) error
type runFunc func(context.Context, authclient.ClientI) error

func (c *Command) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
handler, ok := c.innerCmdMap[cmd]
Expand All @@ -136,7 +136,7 @@ func (c *Command) TryRun(ctx context.Context, cmd string, clientFunc commonclien
}
}

func (c *cmdHandler) onAuditQueryExec(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryExec(ctx context.Context, authClient authclient.ClientI) error {
if c.auditQuery == "" {
buff, err := io.ReadAll(os.Stdin)
if err != nil {
Expand All @@ -154,7 +154,7 @@ func (c *cmdHandler) onAuditQueryExec(ctx context.Context, authClient *authclien
return nil
}

func (c *cmdHandler) onAuditQueryGet(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryGet(ctx context.Context, authClient authclient.ClientI) error {
auditQuery, err := authClient.SecReportsClient().GetSecurityAuditQuery(ctx, c.name)
if err != nil {
return trace.Wrap(err)
Expand All @@ -165,7 +165,7 @@ func (c *cmdHandler) onAuditQueryGet(ctx context.Context, authClient *authclient
return nil
}

func (c *cmdHandler) onAuditQueryLs(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryLs(ctx context.Context, authClient authclient.ClientI) error {
auditQueries, err := authClient.SecReportsClient().GetSecurityAuditQueries(ctx)
if err != nil {
return trace.Wrap(err)
Expand All @@ -176,14 +176,14 @@ func (c *cmdHandler) onAuditQueryLs(ctx context.Context, authClient *authclient.
return nil
}

func (c *cmdHandler) onAuditQueryRm(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryRm(ctx context.Context, authClient authclient.ClientI) error {
if err := authClient.SecReportsClient().DeleteSecurityAuditQuery(ctx, c.name); err != nil {
return trace.Wrap(err)
}
return nil
}

func (c *cmdHandler) onAuditQuerySchema(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQuerySchema(ctx context.Context, authClient authclient.ClientI) error {
resp, err := authClient.SecReportsClient().GetSchema(ctx)
if err != nil {
return trace.Wrap(err)
Expand All @@ -201,7 +201,7 @@ func (c *cmdHandler) onAuditQuerySchema(ctx context.Context, authClient *authcli
return nil
}

func (c *cmdHandler) onAuditQueryCreate(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryCreate(ctx context.Context, authClient authclient.ClientI) error {
if c.auditQuery == "" {
return trace.BadParameter("audit query required")
}
Expand All @@ -221,7 +221,7 @@ func (c *cmdHandler) onAuditQueryCreate(ctx context.Context, authClient *authcli
return nil
}

func (c *cmdHandler) onAuditReportLs(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditReportLs(ctx context.Context, authClient authclient.ClientI) error {
reports, err := authClient.SecReportsClient().GetSecurityReports(ctx)
if err != nil {
return trace.Wrap(err)
Expand All @@ -232,7 +232,7 @@ func (c *cmdHandler) onAuditReportLs(ctx context.Context, authClient *authclient
return trace.Wrap(err)
}

func (c *cmdHandler) onAuditReportGet(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditReportGet(ctx context.Context, authClient authclient.ClientI) error {
details, err := authClient.SecReportsClient().GetSecurityReportResult(ctx, c.name, c.days)
if err != nil {
return trace.Wrap(err)
Expand All @@ -243,15 +243,15 @@ func (c *cmdHandler) onAuditReportGet(ctx context.Context, authClient *authclien
return nil
}

func (c *cmdHandler) onAuditReportRun(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditReportRun(ctx context.Context, authClient authclient.ClientI) error {
err := authClient.SecReportsClient().RunSecurityReport(ctx, c.name, c.days)
if err != nil {
return trace.Wrap(err)
}
return nil
}

func (c *cmdHandler) onAuditReportState(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditReportState(ctx context.Context, authClient authclient.ClientI) error {
state, err := authClient.SecReportsClient().GetSecurityReportExecutionState(ctx, c.name, int32(c.days))
if err != nil {
return trace.Wrap(err)
Expand Down
12 changes: 6 additions & 6 deletions tool/tctl/common/acl_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (c *ACLCommand) Initialize(app *kingpin.Application, _ *tctlcfg.GlobalCLIFl

// TryRun takes the CLI command as an argument (like "acl ls") and executes it.
func (c *ACLCommand) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
var commandFunc func(ctx context.Context, client *authclient.Client) error
var commandFunc func(ctx context.Context, client authclient.ClientI) error
switch cmd {
case c.ls.FullCommand():
commandFunc = c.List
Expand All @@ -122,7 +122,7 @@ func (c *ACLCommand) TryRun(ctx context.Context, cmd string, clientFunc commoncl
}

// List will list access lists visible to the user.
func (c *ACLCommand) List(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) List(ctx context.Context, client authclient.ClientI) error {
var accessLists []*accesslist.AccessList
var nextKey string
for {
Expand All @@ -149,7 +149,7 @@ func (c *ACLCommand) List(ctx context.Context, client *authclient.Client) error
}

// Get will display information about an access list visible to the user.
func (c *ACLCommand) Get(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) Get(ctx context.Context, client authclient.ClientI) error {
accessList, err := client.AccessListClient().GetAccessList(ctx, c.accessListName)
if err != nil {
return trace.Wrap(err)
Expand All @@ -159,7 +159,7 @@ func (c *ACLCommand) Get(ctx context.Context, client *authclient.Client) error {
}

// UsersAdd will add a user to an access list.
func (c *ACLCommand) UsersAdd(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) UsersAdd(ctx context.Context, client authclient.ClientI) error {
var expires time.Time
if c.expires != "" {
var err error
Expand Down Expand Up @@ -205,7 +205,7 @@ func (c *ACLCommand) UsersAdd(ctx context.Context, client *authclient.Client) er
}

// UsersRemove will remove a user to an access list.
func (c *ACLCommand) UsersRemove(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) UsersRemove(ctx context.Context, client authclient.ClientI) error {
err := client.AccessListClient().DeleteAccessListMember(ctx, c.accessListName, c.userName)
if err != nil {
return trace.Wrap(err)
Expand All @@ -217,7 +217,7 @@ func (c *ACLCommand) UsersRemove(ctx context.Context, client *authclient.Client)
}

// UsersList will list the users in an access list.
func (c *ACLCommand) UsersList(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) UsersList(ctx context.Context, client authclient.ClientI) error {
var (
allMembers []*accesslist.AccessListMember
nextToken string
Expand Down
2 changes: 1 addition & 1 deletion tool/tctl/common/admin_action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ func runTestCase(t *testing.T, ctx context.Context, client *authclient.Client, t
commandName, err := app.Parse(args)
require.NoError(t, err)

match, err := tc.cliCommand.TryRun(ctx, commandName, func(context.Context) (*authclient.Client, func(context.Context), error) {
match, err := tc.cliCommand.TryRun(ctx, commandName, func(context.Context) (authclient.ClientI, func(context.Context), error) {
return client, func(context.Context) {}, nil
})
require.True(t, match)
Expand Down
12 changes: 6 additions & 6 deletions tool/tctl/common/alert_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (c *AlertCommand) Initialize(app *kingpin.Application, _ *tctlcfg.GlobalCLI

// TryRun takes the CLI command as an argument (like "alerts ls") and executes it.
func (c *AlertCommand) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
var commandFunc func(ctx context.Context, client *authclient.Client) error
var commandFunc func(ctx context.Context, client authclient.ClientI) error
switch cmd {
case c.alertList.FullCommand():
commandFunc = c.List
Expand All @@ -117,7 +117,7 @@ func (c *AlertCommand) TryRun(ctx context.Context, cmd string, clientFunc common
return true, trace.Wrap(err)
}

func (c *AlertCommand) ListAck(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) ListAck(ctx context.Context, client authclient.ClientI) error {
acks, err := client.GetAlertAcks(ctx)
if err != nil {
return trace.Wrap(err)
Expand All @@ -135,7 +135,7 @@ func (c *AlertCommand) ListAck(ctx context.Context, client *authclient.Client) e
return nil
}

func (c *AlertCommand) Ack(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) Ack(ctx context.Context, client authclient.ClientI) error {
if c.clear {
return c.ClearAck(ctx, client)
}
Expand Down Expand Up @@ -164,7 +164,7 @@ func (c *AlertCommand) Ack(ctx context.Context, client *authclient.Client) error
return nil
}

func (c *AlertCommand) ClearAck(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) ClearAck(ctx context.Context, client authclient.ClientI) error {
req := proto.ClearAlertAcksRequest{
AlertID: c.alertID,
}
Expand All @@ -178,7 +178,7 @@ func (c *AlertCommand) ClearAck(ctx context.Context, client *authclient.Client)
return nil
}

func (c *AlertCommand) List(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) List(ctx context.Context, client authclient.ClientI) error {
labels, err := libclient.ParseLabelSpec(c.labels)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -269,7 +269,7 @@ func displayAlertsJSON(alerts []types.ClusterAlert) error {
return nil
}

func (c *AlertCommand) Create(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) Create(ctx context.Context, client authclient.ClientI) error {
labels, err := libclient.ParseLabelSpec(c.labels)
if err != nil {
return trace.Wrap(err)
Expand Down
4 changes: 2 additions & 2 deletions tool/tctl/common/app_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (c *AppsCommand) Initialize(app *kingpin.Application, _ *tctlcfg.GlobalCLIF

// TryRun attempts to run subcommands like "apps ls".
func (c *AppsCommand) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
var commandFunc func(ctx context.Context, client *authclient.Client) error
var commandFunc func(ctx context.Context, client authclient.ClientI) error
switch cmd {
case c.appsList.FullCommand():
commandFunc = c.ListApps
Expand All @@ -90,7 +90,7 @@ func (c *AppsCommand) TryRun(ctx context.Context, cmd string, clientFunc commonc

// ListApps prints the list of applications that have recently sent heartbeats
// to the cluster.
func (c *AppsCommand) ListApps(ctx context.Context, clt *authclient.Client) error {
func (c *AppsCommand) ListApps(ctx context.Context, clt authclient.ClientI) error {
labels, err := libclient.ParseLabelSpec(c.labels)
if err != nil {
return trace.Wrap(err)
Expand Down
Loading

0 comments on commit 1d542fc

Please sign in to comment.