diff --git a/clients/go/admin/client.go b/clients/go/admin/client.go index 970dcddba..3669e6e5d 100644 --- a/clients/go/admin/client.go +++ b/clients/go/admin/client.go @@ -9,6 +9,8 @@ import ( "strings" "sync" + "google.golang.org/grpc/backoff" + "github.com/coreos/go-oidc" "github.com/flyteorg/flyteidl/clients/go/admin/mocks" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" @@ -55,10 +57,11 @@ func NewAuthClient(ctx context.Context, conn *grpc.ClientConn) service.AuthServi func GetAdditionalAdminClientConfigOptions(cfg Config) []grpc.DialOption { opts := make([]grpc.DialOption, 0, 2) - backoffConfig := grpc.BackoffConfig{ + backoffConfig := backoff.Config{ MaxDelay: cfg.MaxBackoffDelay.Duration, } - opts = append(opts, grpc.WithBackoffConfig(backoffConfig)) + + opts = append(opts, grpc.WithConnectParams(grpc.ConnectParams{Backoff: backoffConfig})) timeoutDialOption := grpc_retry.WithPerRetryTimeout(cfg.PerRetryTimeout.Duration) maxRetriesOption := grpc_retry.WithMax(uint(cfg.MaxRetries)) @@ -98,18 +101,26 @@ func getTokenEndpointFromAuthServer(ctx context.Context, authorizationServer str // This retrieves a DialOption that contains a source for generating JWTs for authentication with Flyte Admin. // It will first attempt to retrieve the token endpoint by making a metadata call. If that fails, but the token endpoint // is set in the config, that will be used instead. -func getAuthenticationDialOption(ctx context.Context, cfg Config) (grpc.DialOption, error) { - var tokenURL string - tokenURL, err := getTokenEndpointFromAuthServer(ctx, cfg.AuthorizationServerURL) - if err != nil || tokenURL == "" { - logger.Infof(ctx, "No token URL found from configuration Issuer, looking for token endpoint directly") +func getAuthenticationDialOption(ctx context.Context, cfg Config, dialOpts []grpc.DialOption) (grpc.DialOption, error) { + conn, err := grpc.Dial(cfg.Endpoint.String(), dialOpts...) + if err != nil { + return nil, err + } + + tempClient := NewAuthClient(ctx, conn) + tokenURL := cfg.TokenURL + if len(tokenURL) == 0 { + metadata, err := tempClient.OAuth2Metadata(ctx, &service.OAuth2MetadataRequest{}) if err != nil { - logger.Errorf(ctx, "Err is %s", err) - } - tokenURL = cfg.TokenURL - if tokenURL == "" { - return nil, errors.New("no token endpoint could be found") + return nil, fmt.Errorf("failed to fetch auth metadata. Error: %v", err) } + + tokenURL = metadata.TokenEndpoint + } + + clientMetadata, err := tempClient.FlyteClient(ctx, &service.FlyteClientRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err) } secretBytes, err := ioutil.ReadFile(cfg.ClientSecretLocation) @@ -117,21 +128,27 @@ func getAuthenticationDialOption(ctx context.Context, cfg Config) (grpc.DialOpti logger.Errorf(ctx, "Error reading secret from location %s", cfg.ClientSecretLocation) return nil, err } + secret := strings.TrimSpace(string(secretBytes)) + scopes := cfg.Scopes + if len(scopes) == 0 { + scopes = clientMetadata.Scopes + } + ccConfig := clientcredentials.Config{ ClientID: cfg.ClientID, ClientSecret: secret, TokenURL: tokenURL, - Scopes: cfg.Scopes, + Scopes: scopes, } tSource := ccConfig.TokenSource(ctx) - oauthTokenSource := NewCustomHeaderTokenSource(tSource, cfg.AuthorizationHeader) + oauthTokenSource := NewCustomHeaderTokenSource(tSource, clientMetadata.AuthorizationMetadataKey) return grpc.WithPerRPCCredentials(oauthTokenSource), nil } func NewAdminConnection(ctx context.Context, cfg Config) (*grpc.ClientConn, error) { - var opts []grpc.DialOption + opts := GetAdditionalAdminClientConfigOptions(cfg) if cfg.UseInsecureConnection { opts = append(opts, grpc.WithInsecure()) @@ -141,7 +158,7 @@ func NewAdminConnection(ctx context.Context, cfg Config) (*grpc.ClientConn, erro opts = append(opts, grpc.WithTransportCredentials(creds)) if cfg.UseAuth { logger.Infof(ctx, "Instantiating a token source to authenticate against Admin, ID: %s", cfg.ClientID) - jwtDialOption, err := getAuthenticationDialOption(ctx, cfg) + jwtDialOption, err := getAuthenticationDialOption(ctx, cfg, opts) if err != nil { return nil, err } @@ -149,7 +166,6 @@ func NewAdminConnection(ctx context.Context, cfg Config) (*grpc.ClientConn, erro } } - opts = append(opts, GetAdditionalAdminClientConfigOptions(cfg)...) return grpc.Dial(cfg.Endpoint.String(), opts...) } diff --git a/clients/go/admin/config.go b/clients/go/admin/config.go index 4aebc82b4..8518e0f0c 100644 --- a/clients/go/admin/config.go +++ b/clients/go/admin/config.go @@ -2,6 +2,7 @@ package admin import ( "context" + "path/filepath" "time" "github.com/flyteorg/flytestdlib/config" @@ -10,7 +11,12 @@ import ( //go:generate pflags Config --default-var=defaultConfig -const configSectionKey = "admin" +const ( + configSectionKey = "admin" + DefaultClientID = "flytepropeller" +) + +var DefaultClientSecretLocation = filepath.Join(string(filepath.Separator), "etc", "secrets", "client_secret") type Config struct { Endpoint config.URL `json:"endpoint" pflag:",For admin types, specify where the uri of the service is located."` @@ -28,28 +34,34 @@ type Config struct { // There are two ways to get the token URL. If the authorization server url is provided, the client will try to use RFC 8414 to // try to get the token URL. Or it can be specified directly through TokenURL config. - AuthorizationServerURL string `json:"authorizationServerUrl" pflag:",This is the URL to your IDP's authorization server'"` - TokenURL string `json:"tokenUrl" pflag:",Your IDPs token endpoint"` + // Deprecated. This will now be discovered through admin's anonymously accessible metadata. + DeprecatedAuthorizationServerURL string `json:"authorizationServerUrl" pflag:",This is the URL to your IdP's authorization server. It'll default to Endpoint"` + // If not provided, it'll be discovered through admin's anonymously accessible metadata endpoint. + TokenURL string `json:"tokenUrl" pflag:",OPTIONAL: Your IdP's token endpoint."` // See the implementation of the 'grpcAuthorizationHeader' option in Flyte Admin for more information. But // basically we want to be able to use a different string to pass the token from this client to the the Admin service // because things might be running in a service mesh (like Envoy) that already uses the default 'authorization' header - AuthorizationHeader string `json:"authorizationHeader" pflag:",Custom metadata header to pass JWT"` + // Deprecated. It will automatically be discovered through an anonymously accessible auth metadata service. + DeprecatedAuthorizationHeader string `json:"authorizationHeader" pflag:",Custom metadata header to pass JWT"` } var ( defaultConfig = Config{ - MaxBackoffDelay: config.Duration{Duration: 8 * time.Second}, - PerRetryTimeout: config.Duration{Duration: 15 * time.Second}, - MaxRetries: 4, + MaxBackoffDelay: config.Duration{Duration: 8 * time.Second}, + PerRetryTimeout: config.Duration{Duration: 15 * time.Second}, + MaxRetries: 4, + ClientID: DefaultClientID, + ClientSecretLocation: DefaultClientSecretLocation, } + configSection = config.MustRegisterSectionWithUpdates(configSectionKey, &defaultConfig, func(ctx context.Context, newValue config.Config) { if newValue.(*Config).MaxRetries < 0 { logger.Panicf(ctx, "Admin configuration given with negative gRPC retry value.") } if newValue.(*Config).UseAuth { - logger.Warnf(ctx, "Admin client config has authentication ON with server %s", newValue.(*Config).AuthorizationServerURL) + logger.Warnf(ctx, "Admin client config has authentication ON with server %s", newValue.(*Config).DeprecatedAuthorizationServerURL) } }) ) diff --git a/clients/go/admin/config_flags.go b/clients/go/admin/config_flags.go index 541634d86..93d2a836e 100755 --- a/clients/go/admin/config_flags.go +++ b/clients/go/admin/config_flags.go @@ -50,8 +50,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "clientId"), defaultConfig.ClientID, "Client ID") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "clientSecretLocation"), defaultConfig.ClientSecretLocation, "File containing the client secret") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "scopes"), []string{}, "List of scopes to request") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationServerUrl"), defaultConfig.AuthorizationServerURL, "This is the URL to your IDP's authorization server'") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationServerUrl"), defaultConfig.DeprecatedAuthorizationServerURL, "This is the URL to your IDP's authorization server'") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "tokenUrl"), defaultConfig.TokenURL, "Your IDPs token endpoint") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationHeader"), defaultConfig.AuthorizationHeader, "Custom metadata header to pass JWT") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationHeader"), defaultConfig.DeprecatedAuthorizationHeader, "Custom metadata header to pass JWT") return cmdFlags } diff --git a/clients/go/admin/config_flags_test.go b/clients/go/admin/config_flags_test.go index 4a7440c8a..9b3b37154 100755 --- a/clients/go/admin/config_flags_test.go +++ b/clients/go/admin/config_flags_test.go @@ -301,7 +301,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("authorizationServerUrl"); err == nil { - assert.Equal(t, string(defaultConfig.AuthorizationServerURL), vString) + assert.Equal(t, string(defaultConfig.DeprecatedAuthorizationServerURL), vString) } else { assert.FailNow(t, err.Error()) } @@ -312,7 +312,7 @@ func TestConfig_SetFlags(t *testing.T) { cmdFlags.Set("authorizationServerUrl", testValue) if vString, err := cmdFlags.GetString("authorizationServerUrl"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthorizationServerURL) + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DeprecatedAuthorizationServerURL) } else { assert.FailNow(t, err.Error()) @@ -345,7 +345,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly if vString, err := cmdFlags.GetString("authorizationHeader"); err == nil { - assert.Equal(t, string(defaultConfig.AuthorizationHeader), vString) + assert.Equal(t, string(defaultConfig.DeprecatedAuthorizationHeader), vString) } else { assert.FailNow(t, err.Error()) } @@ -356,7 +356,7 @@ func TestConfig_SetFlags(t *testing.T) { cmdFlags.Set("authorizationHeader", testValue) if vString, err := cmdFlags.GetString("authorizationHeader"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AuthorizationHeader) + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DeprecatedAuthorizationHeader) } else { assert.FailNow(t, err.Error()) diff --git a/clients/go/admin/integration_test.go b/clients/go/admin/integration_test.go index 64beb265e..2cd904dc3 100644 --- a/clients/go/admin/integration_test.go +++ b/clients/go/admin/integration_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "google.golang.org/grpc" + "golang.org/x/oauth2/clientcredentials" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" @@ -22,14 +24,14 @@ func TestLiveAdminClient(t *testing.T) { u, err := url.Parse("dns:///flyte.lyft.net") assert.NoError(t, err) client := InitializeAdminClient(ctx, Config{ - Endpoint: config.URL{URL: *u}, - UseInsecureConnection: false, - UseAuth: true, - ClientId: "0oacmtueinpXk72Af1t7", - ClientSecretLocation: "/Users/username/.ssh/admin/propeller_secret", - AuthorizationServerURL: "https://lyft.okta.com/oauth2/ausc5wmjw96cRKvTd1t7", - Scopes: []string{"svc"}, - AuthorizationHeader: "Flyte-Authorization", + Endpoint: config.URL{URL: *u}, + UseInsecureConnection: false, + UseAuth: true, + ClientID: "0oacmtueinpXk72Af1t7", + ClientSecretLocation: "/Users/username/.ssh/admin/propeller_secret", + DeprecatedAuthorizationServerURL: "https://lyft.okta.com/oauth2/ausc5wmjw96cRKvTd1t7", + Scopes: []string{"svc"}, + DeprecatedAuthorizationHeader: "Flyte-Authorization", }) resp, err := client.ListProjects(ctx, &admin.ProjectListRequest{}) @@ -52,9 +54,10 @@ func TestGetDialOption(t *testing.T) { ctx := context.Background() cfg := Config{ - AuthorizationServerURL: "https://lyft.okta.com/oauth2/ausc5wmjw96cRKvTd1t7", + DeprecatedAuthorizationServerURL: "https://lyft.okta.com/oauth2/ausc5wmjw96cRKvTd1t7", } - dialOption, err := getAuthenticationDialOption(ctx, cfg) + + dialOption, err := getAuthenticationDialOption(ctx, cfg, []grpc.DialOption{}) assert.NoError(t, err) assert.NotNil(t, dialOption) }