Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SSO MFA - Add ephemeral SSO MFA device #46704

Merged
merged 14 commits into from
Oct 16, 2024
9 changes: 9 additions & 0 deletions api/proto/teleport/legacy/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3736,6 +3736,7 @@ message MFADevice {
TOTPDevice totp = 8;
U2FDevice u2f = 9;
WebauthnDevice webauthn = 10;
SSOMFADevice sso = 11;
}
}

Expand Down Expand Up @@ -3800,6 +3801,14 @@ message WebauthnDevice {
google.protobuf.BoolValue credential_backed_up = 10;
}

// SSOMFADevice contains details of an SSO MFA method.
message SSOMFADevice {
// connector_id is the ID of the SSO connector.
string connector_id = 1;
// connector_type is the type of the SSO connector.
string connector_type = 2;
}

// WebauthnLocalAuth holds settings necessary for local webauthn use.
message WebauthnLocalAuth {
// UserID is the random user handle generated for the user.
Expand Down
104 changes: 0 additions & 104 deletions api/types/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package types

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand All @@ -27,7 +26,6 @@ import (
"strings"
"time"

"github.com/gogo/protobuf/jsonpb"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/constants"
Expand Down Expand Up @@ -939,108 +937,6 @@ func (wal *WebauthnLocalAuth) Check() error {
return nil
}

// NewMFADevice creates a new MFADevice with the given name. Caller must set
// the Device field in the returned MFADevice.
func NewMFADevice(name, id string, addedAt time.Time) *MFADevice {
return &MFADevice{
Metadata: Metadata{
Name: name,
},
Id: id,
AddedAt: addedAt,
LastUsed: addedAt,
}
}

// setStaticFields sets static resource header and metadata fields.
func (d *MFADevice) setStaticFields() {
d.Kind = KindMFADevice
d.Version = V1
}

// CheckAndSetDefaults validates MFADevice fields and populates empty fields
// with default values.
func (d *MFADevice) CheckAndSetDefaults() error {
d.setStaticFields()
if err := d.Metadata.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
if d.Id == "" {
return trace.BadParameter("MFADevice missing ID field")
}
if d.AddedAt.IsZero() {
return trace.BadParameter("MFADevice missing AddedAt field")
}
if d.LastUsed.IsZero() {
return trace.BadParameter("MFADevice missing LastUsed field")
}
if d.LastUsed.Before(d.AddedAt) {
return trace.BadParameter("MFADevice LastUsed field must be earlier than AddedAt")
}
if d.Device == nil {
return trace.BadParameter("MFADevice missing Device field")
}
if err := checkWebauthnDevice(d); err != nil {
return trace.Wrap(err)
}
return nil
}

func checkWebauthnDevice(d *MFADevice) error {
wrapper, ok := d.Device.(*MFADevice_Webauthn)
if !ok {
return nil
}
switch webDev := wrapper.Webauthn; {
case webDev == nil:
return trace.BadParameter("MFADevice has malformed WebauthnDevice")
case len(webDev.CredentialId) == 0:
return trace.BadParameter("WebauthnDevice missing CredentialId field")
case len(webDev.PublicKeyCbor) == 0:
return trace.BadParameter("WebauthnDevice missing PublicKeyCbor field")
default:
return nil
}
}

func (d *MFADevice) GetKind() string { return d.Kind }
func (d *MFADevice) GetSubKind() string { return d.SubKind }
func (d *MFADevice) SetSubKind(sk string) { d.SubKind = sk }
func (d *MFADevice) GetVersion() string { return d.Version }
func (d *MFADevice) GetMetadata() Metadata { return d.Metadata }
func (d *MFADevice) GetName() string { return d.Metadata.GetName() }
func (d *MFADevice) SetName(n string) { d.Metadata.SetName(n) }
func (d *MFADevice) GetRevision() string { return d.Metadata.GetRevision() }
func (d *MFADevice) SetRevision(rev string) { d.Metadata.SetRevision(rev) }
func (d *MFADevice) Expiry() time.Time { return d.Metadata.Expiry() }
func (d *MFADevice) SetExpiry(exp time.Time) { d.Metadata.SetExpiry(exp) }

// MFAType returns the human-readable name of the MFA protocol of this device.
func (d *MFADevice) MFAType() string {
switch d.Device.(type) {
case *MFADevice_Totp:
return "TOTP"
case *MFADevice_U2F:
return "U2F"
case *MFADevice_Webauthn:
return "WebAuthn"
default:
return "unknown"
}
}

func (d *MFADevice) MarshalJSON() ([]byte, error) {
buf := new(bytes.Buffer)
err := (&jsonpb.Marshaler{}).Marshal(buf, d)
return buf.Bytes(), trace.Wrap(err)
}

func (d *MFADevice) UnmarshalJSON(buf []byte) error {
unmarshaler := jsonpb.Unmarshaler{AllowUnknownFields: true}
err := unmarshaler.Unmarshal(bytes.NewReader(buf), d)
return trace.Wrap(err)
}

// IsSessionMFARequired returns whether this RequireMFAType requires per-session MFA.
func (r RequireMFAType) IsSessionMFARequired() bool {
return r != RequireMFAType_OFF
Expand Down
6 changes: 5 additions & 1 deletion api/types/authentication_mfadevice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ func TestMFADevice_CheckAndSetDefaults(t *testing.T) {
Id: "otp-0001",
AddedAt: now,
LastUsed: now,
Device: &types.MFADevice_Totp{}, // validated elsewhere
Device: &types.MFADevice_Totp{
Totp: &types.TOTPDevice{
Key: "key",
},
},
},
},
{
Expand Down
134 changes: 134 additions & 0 deletions api/types/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,103 @@
package types

import (
"bytes"
"time"

"github.com/gogo/protobuf/jsonpb"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/utils"
)

// NewMFADevice creates a new MFADevice with the given name. Caller must set
// the Device field in the returned MFADevice.
func NewMFADevice(name, id string, addedAt time.Time, device isMFADevice_Device) (*MFADevice, error) {
dev := &MFADevice{
Metadata: Metadata{
Name: name,
},
Id: id,
AddedAt: addedAt,
LastUsed: addedAt,
Device: device,
}
return dev, dev.CheckAndSetDefaults()
Joerger marked this conversation as resolved.
Show resolved Hide resolved
}

// setStaticFields sets static resource header and metadata fields.
func (d *MFADevice) setStaticFields() {
d.Kind = KindMFADevice
d.Version = V1
}

// CheckAndSetDefaults validates MFADevice fields and populates empty fields
// with default values.
func (d *MFADevice) CheckAndSetDefaults() error {
codingllama marked this conversation as resolved.
Show resolved Hide resolved
d.setStaticFields()
if err := d.Metadata.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
if d.Id == "" {
return trace.BadParameter("MFADevice missing ID field")
}
if d.AddedAt.IsZero() {
return trace.BadParameter("MFADevice missing AddedAt field")
}
if d.LastUsed.IsZero() {
return trace.BadParameter("MFADevice missing LastUsed field")
}
if d.LastUsed.Before(d.AddedAt) {
return trace.BadParameter("MFADevice LastUsed field must be earlier than AddedAt")
}
if d.Device == nil {
return trace.BadParameter("MFADevice missing Device field")
}
if err := d.validateDevice(); err != nil {
return trace.Wrap(err)
}
return nil
}

// validateDevice runs additional validations for OTP devices.
// Prefer adding new validation logic to types.MFADevice.CheckAndSetDefaults
// instead.
func (d *MFADevice) validateDevice() error {
switch dev := d.Device.(type) {
case *MFADevice_Totp:
if dev.Totp == nil {
return trace.BadParameter("MFADevice has malformed TOTPDevice")
}
if dev.Totp.Key == "" {
return trace.BadParameter("TOTPDevice missing Key field")
}
case *MFADevice_Webauthn:
if dev.Webauthn == nil {
return trace.BadParameter("MFADevice has malformed WebauthnDevice")
}
if len(dev.Webauthn.CredentialId) == 0 {
return trace.BadParameter("WebauthnDevice missing CredentialId field")
}
if len(dev.Webauthn.PublicKeyCbor) == 0 {
return trace.BadParameter("WebauthnDevice missing PublicKeyCbor field")
}
case *MFADevice_Sso:
if dev.Sso == nil {
return trace.BadParameter("MFADevice has malformed SSODevice")
}
if dev.Sso.ConnectorId == "" {
return trace.BadParameter("SSODevice missing ConnectorId field")
}
if dev.Sso.ConnectorType == "" {
return trace.BadParameter("SSODevice missing ConnectorType field")
}
case *MFADevice_U2F:
default:
return trace.BadParameter("MFADevice has Device field of unknown type %T", dev)
}
return nil
}

func (d *MFADevice) WithoutSensitiveData() (*MFADevice, error) {
if d == nil {
return nil, trace.BadParameter("cannot hide sensitive data on empty object")
Expand All @@ -33,9 +125,51 @@ func (d *MFADevice) WithoutSensitiveData() (*MFADevice, error) {
// OK, no sensitive secrets.
case *MFADevice_Webauthn:
// OK, no sensitive secrets.
case *MFADevice_Sso:
// OK, no sensitive secrets.
default:
return nil, trace.BadParameter("unsupported MFADevice type %T", d.Device)
}

return out, nil
}

func (d *MFADevice) GetKind() string { return d.Kind }
func (d *MFADevice) GetSubKind() string { return d.SubKind }
func (d *MFADevice) SetSubKind(sk string) { d.SubKind = sk }
func (d *MFADevice) GetVersion() string { return d.Version }
func (d *MFADevice) GetMetadata() Metadata { return d.Metadata }
func (d *MFADevice) GetName() string { return d.Metadata.GetName() }
func (d *MFADevice) SetName(n string) { d.Metadata.SetName(n) }
func (d *MFADevice) GetRevision() string { return d.Metadata.GetRevision() }
func (d *MFADevice) SetRevision(rev string) { d.Metadata.SetRevision(rev) }
func (d *MFADevice) Expiry() time.Time { return d.Metadata.Expiry() }
func (d *MFADevice) SetExpiry(exp time.Time) { d.Metadata.SetExpiry(exp) }

// MFAType returns the human-readable name of the MFA protocol of this device.
func (d *MFADevice) MFAType() string {
switch d.Device.(type) {
case *MFADevice_Totp:
return "TOTP"
case *MFADevice_U2F:
return "U2F"
case *MFADevice_Webauthn:
return "WebAuthn"
case *MFADevice_Sso:
return "SSO"
default:
return "unknown"
}
}

func (d *MFADevice) MarshalJSON() ([]byte, error) {
buf := new(bytes.Buffer)
err := (&jsonpb.Marshaler{}).Marshal(buf, d)
return buf.Bytes(), trace.Wrap(err)
}

func (d *MFADevice) UnmarshalJSON(buf []byte) error {
unmarshaler := jsonpb.Unmarshaler{AllowUnknownFields: true}
err := unmarshaler.Unmarshal(bytes.NewReader(buf), d)
return trace.Wrap(err)
}
Loading
Loading