diff --git a/tool/tctl/common/auth_command.go b/tool/tctl/common/auth_command.go index eb4efd832a4eb..1beb0790f2a79 100644 --- a/tool/tctl/common/auth_command.go +++ b/tool/tctl/common/auth_command.go @@ -68,6 +68,9 @@ type AuthCommand struct { leafCluster string kubeCluster string appName string + dbService string + dbName string + dbUser string signOverwrite bool rotateGracePeriod time.Duration @@ -119,7 +122,10 @@ func (a *AuthCommand) Initialize(app *kingpin.Application, config *service.Confi a.authSign.Flag("kube-cluster", `Leaf cluster to generate identity file for when --format is set to "kubernetes"`).Hidden().StringVar(&a.leafCluster) a.authSign.Flag("leaf-cluster", `Leaf cluster to generate identity file for when --format is set to "kubernetes"`).StringVar(&a.leafCluster) a.authSign.Flag("kube-cluster-name", `Kubernetes cluster to generate identity file for when --format is set to "kubernetes"`).StringVar(&a.kubeCluster) - a.authSign.Flag("app-name", `Application to generate identity file for`).StringVar(&a.appName) + a.authSign.Flag("app-name", `Application to generate identity file for. Mutually exclusive with "--db-service".`).StringVar(&a.appName) + a.authSign.Flag("db-service", `Database to generate identity file for. Mutually exclusive with "--app-name".`).StringVar(&a.dbService) + a.authSign.Flag("db-user", `Database user placed on the identity file. Only used when "--db-service" is set.`).StringVar(&a.dbUser) + a.authSign.Flag("db-name", `Database name placed on the identity file. Only used when "--db-service" is set.`).StringVar(&a.dbName) a.authRotate = auth.Command("rotate", "Rotate certificate authorities in the cluster") a.authRotate.Flag("grace-period", "Grace period keeps previous certificate authorities signatures valid, if set to 0 will force users to relogin and nodes to re-register."). @@ -595,10 +601,19 @@ func (a *AuthCommand) generateUserKeys(ctx context.Context, clusterAPI auth.Clie return trace.Wrap(err) } - var routeToApp proto.RouteToApp - var certUsage proto.UserCertsRequest_CertUsage + var ( + routeToApp proto.RouteToApp + routeToDatabase proto.RouteToDatabase + certUsage proto.UserCertsRequest_CertUsage + ) + + // `appName` and `db` are mutually exclusive. + if a.appName != "" && a.dbService != "" { + return trace.BadParameter("only --app-name or --db-service can be set, not both") + } - if a.appName != "" { + switch { + case a.appName != "": server, err := getApplicationServer(ctx, clusterAPI, a.appName) if err != nil { return trace.Wrap(err) @@ -620,6 +635,19 @@ func (a *AuthCommand) generateUserKeys(ctx context.Context, clusterAPI auth.Clie SessionID: appSession.GetName(), } certUsage = proto.UserCertsRequest_App + case a.dbService != "": + server, err := getDatabaseServer(context.TODO(), clusterAPI, a.dbService) + if err != nil { + return trace.Wrap(err) + } + + routeToDatabase = proto.RouteToDatabase{ + ServiceName: a.dbService, + Protocol: server.GetDatabase().GetProtocol(), + Database: a.dbName, + Username: a.dbUser, + } + certUsage = proto.UserCertsRequest_Database } reqExpiry := time.Now().UTC().Add(a.genTTL) @@ -633,6 +661,7 @@ func (a *AuthCommand) generateUserKeys(ctx context.Context, clusterAPI auth.Clie KubernetesCluster: a.kubeCluster, RouteToApp: routeToApp, Usage: certUsage, + RouteToDatabase: routeToDatabase, }) if err != nil { return trace.Wrap(err) @@ -833,3 +862,20 @@ func getApplicationServer(ctx context.Context, clusterAPI auth.ClientI, appName } return nil, trace.NotFound("app %q not found", appName) } + +// getDatabaseServer fetches a single `DatabaseServer` by name using the +// provided `auth.ClientI`. +func getDatabaseServer(ctx context.Context, clientAPI auth.ClientI, dbName string) (types.DatabaseServer, error) { + servers, err := clientAPI.GetDatabaseServers(ctx, apidefaults.Namespace) + if err != nil { + return nil, trace.Wrap(err) + } + + for _, server := range servers { + if server.GetName() == dbName { + return server, nil + } + } + + return nil, trace.NotFound("database %q not found", dbName) +} diff --git a/tool/tctl/common/auth_command_test.go b/tool/tctl/common/auth_command_test.go index 60fad9eda1f91..1694dff2994a4 100644 --- a/tool/tctl/common/auth_command_test.go +++ b/tool/tctl/common/auth_command_test.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/identityfile" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/kube/kubeconfig" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/services" @@ -278,6 +279,7 @@ type mockClient struct { remoteClusters []types.RemoteCluster kubeServices []types.Server appServices []types.AppServer + dbServices []types.DatabaseServer appSession types.WebSession } @@ -313,6 +315,10 @@ func (c *mockClient) CreateAppSession(ctx context.Context, req types.CreateAppSe return c.appSession, nil } +func (c *mockClient) GetDatabaseServers(context.Context, string, ...services.MarshalOption) ([]types.DatabaseServer, error) { + return c.dbServices, nil +} + func TestCheckKubeCluster(t *testing.T) { const teleportCluster = "local-teleport" clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ @@ -673,3 +679,138 @@ func TestGenerateAppCertificates(t *testing.T) { }) } } + +func TestGenerateDatabaseUserCertificates(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { + clusterName string + dbService string + dbName string + dbUser string + expectedDbProtocol string + dbServices []types.DatabaseServer + expectedErr error + }{ + "DatabaseExists": { + clusterName: "example.com", + dbService: "db-1", + expectedDbProtocol: defaults.ProtocolPostgres, + dbServices: []types.DatabaseServer{ + &types.DatabaseServerV3{ + Metadata: types.Metadata{ + Name: "db-1", + }, + Spec: types.DatabaseServerSpecV3{ + Hostname: "example.com", + Database: &types.DatabaseV3{ + Spec: types.DatabaseSpecV3{ + Protocol: defaults.ProtocolPostgres, + }, + }, + }, + }, + }, + }, + "DatabaseWithUserExists": { + clusterName: "example.com", + dbService: "db-user-1", + dbUser: "mongo-user", + expectedDbProtocol: defaults.ProtocolMongoDB, + dbServices: []types.DatabaseServer{ + &types.DatabaseServerV3{ + Metadata: types.Metadata{ + Name: "db-user-1", + }, + Spec: types.DatabaseServerSpecV3{ + Hostname: "example.com", + Database: &types.DatabaseV3{ + Spec: types.DatabaseSpecV3{ + Protocol: defaults.ProtocolMongoDB, + }, + }, + }, + }, + }, + }, + "DatabaseWithDatabaseNameExists": { + clusterName: "example.com", + dbService: "db-user-1", + dbName: "root-database", + expectedDbProtocol: defaults.ProtocolMongoDB, + dbServices: []types.DatabaseServer{ + &types.DatabaseServerV3{ + Metadata: types.Metadata{ + Name: "db-user-1", + }, + Spec: types.DatabaseServerSpecV3{ + Hostname: "example.com", + Database: &types.DatabaseV3{ + Spec: types.DatabaseSpecV3{ + Protocol: defaults.ProtocolMongoDB, + }, + }, + }, + }, + }, + }, + "DatabaseNotFound": { + clusterName: "example.com", + dbService: "db-2", + dbServices: []types.DatabaseServer{}, + expectedErr: trace.NotFound(""), + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + clusterName, err := services.NewClusterNameWithRandomID( + types.ClusterNameSpecV2{ + ClusterName: test.clusterName, + }) + require.NoError(t, err) + + authClient := &mockClient{ + clusterName: clusterName, + userCerts: &proto.Certs{ + SSH: []byte("SSH cert"), + TLS: []byte("TLS cert"), + }, + dbServices: test.dbServices, + } + + certsDir := t.TempDir() + output := filepath.Join(certsDir, test.dbService) + ac := AuthCommand{ + output: output, + outputFormat: identityfile.FormatTLS, + signOverwrite: true, + genTTL: time.Hour, + dbService: test.dbService, + dbName: test.dbName, + dbUser: test.dbUser, + } + + err = ac.generateUserKeys(ctx, authClient) + if test.expectedErr != nil { + require.Error(t, err) + require.IsType(t, test.expectedErr, err) + return + } + + require.NoError(t, err) + + expectedRouteToDatabase := proto.RouteToDatabase{ + ServiceName: test.dbService, + Protocol: test.expectedDbProtocol, + Database: test.dbName, + Username: test.dbUser, + } + require.Equal(t, proto.UserCertsRequest_Database, authClient.userCertsReq.Usage) + require.Equal(t, expectedRouteToDatabase, authClient.userCertsReq.RouteToDatabase) + + certBytes, err := os.ReadFile(filepath.Join(certsDir, test.dbService+".crt")) + require.NoError(t, err) + require.Equal(t, authClient.userCerts.TLS, certBytes, "certificates match") + }) + } +}