diff --git a/docs/configuration/configuration.md b/docs/configuration/configuration.md index 3a107b808..2bd254b90 100644 --- a/docs/configuration/configuration.md +++ b/docs/configuration/configuration.md @@ -14,12 +14,12 @@ Config file expects the keys to have the exact naming as the flags. Any URI passed to flagd via the `--uri` flag must follow one of the 4 following patterns to ensure that it is passed to the correct implementation: -| Sync | Pattern | Example | -|------------|------------------------------------|---------------------------------------| +| Sync | Pattern | Example | +|------------|---------------------------------------|---------------------------------------| | Kubernetes | `core.openfeature.dev/namespace/name` | `core.openfeature.dev/default/my-crd` | -| Filepath | `file:path/to/my/flag` | `file:etc/flagd/my-flags.json` | -| Remote | `http(s)://flag-source-url` | `https://my-flags.com/flags` | -| Grpc | `grpc://flag-source-url` | `grpc://my-flags-server` | +| Filepath | `file:path/to/my/flag` | `file:etc/flagd/my-flags.json` | +| Remote | `http(s)://flag-source-url` | `https://my-flags.com/flags` | +| Grpc | `grpc(s)://flag-source-url` | `grpc://my-flags-server` | ### Customising sync providers @@ -42,11 +42,12 @@ While a URI may be passed to flagd via the `--uri` flag, some implementations ma The flag takes a string argument, which should be a JSON representation of an array of `SourceConfig` objects. Alternatively, these configurations should be passed to flagd via config file, specified using the `--config` flag. -| Field | Type | -|------------|------------------------------------| -| uri | required `string` | | -| provider | required `string` (`file`, `kubernetes`, `http` or `grpc`) | -| bearerToken | optional `string` | +| Field | Type | Note | +|-------------|------------------------------------------------------------|----------------------------------------------------| +| uri | required `string` | | +| provider | required `string` (`file`, `kubernetes`, `http` or `grpc`) | | +| bearerToken | optional `string` | Used for http sync | +| certPath | optional `string` | Used for grpcs sync when TLS certificate is needed | The `uri` field values do not need to follow the [URI patterns](#uri-patterns), the provider type is instead derived from the provider field. If the prefix is supplied, it will be removed on startup without error. @@ -68,4 +69,7 @@ sources: provider: kubernetes - uri: grpc://my-flag-source:8080 provider: grpc +- uri: grpcs://my-flag-source:8080 + provider: grpc + certPath: /certs/ca.cert ``` diff --git a/pkg/runtime/from_config.go b/pkg/runtime/from_config.go index 2e86b9052..590911213 100644 --- a/pkg/runtime/from_config.go +++ b/pkg/runtime/from_config.go @@ -30,16 +30,18 @@ const ( ) var ( - regCrd *regexp.Regexp - regURL *regexp.Regexp - regGRPC *regexp.Regexp - regFile *regexp.Regexp + regCrd *regexp.Regexp + regURL *regexp.Regexp + regGRPC *regexp.Regexp + regGRPCSecure *regexp.Regexp + regFile *regexp.Regexp ) func init() { regCrd = regexp.MustCompile("^core.openfeature.dev/") regURL = regexp.MustCompile("^https?://") regGRPC = regexp.MustCompile("^" + grpc.Prefix) + regGRPCSecure = regexp.MustCompile("^" + grpc.PrefixSecure) regFile = regexp.MustCompile("^file:") } @@ -120,11 +122,12 @@ func (r *Runtime) setSyncImplFromConfig(logger *logger.Logger) error { func (r *Runtime) newGRPC(config sync.SourceConfig, logger *logger.Logger) *grpc.Sync { return &grpc.Sync{ - Target: grpc.URLToGRPCTarget(config.URI), + URI: config.URI, Logger: logger.WithFields( zap.String("component", "sync"), zap.String("sync", "grpc"), ), + CertPath: config.CertPath, } } @@ -211,7 +214,7 @@ func SyncProvidersFromURIs(uris []string) ([]sync.SourceConfig, error) { URI: uri, Provider: syncProviderHTTP, }) - case regGRPC.Match(uriB): + case regGRPC.Match(uriB), regGRPCSecure.Match(uriB): syncProvidersParsed = append(syncProvidersParsed, sync.SourceConfig{ URI: uri, Provider: syncProviderGrpc, diff --git a/pkg/sync/grpc/grpc_sync.go b/pkg/sync/grpc/grpc_sync.go index adc667596..87daecb34 100644 --- a/pkg/sync/grpc/grpc_sync.go +++ b/pkg/sync/grpc/grpc_sync.go @@ -2,12 +2,17 @@ package grpc import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "math" + "os" "strings" msync "sync" "time" + "google.golang.org/grpc/credentials" + "buf.build/gen/go/open-feature/flagd/grpc/go/sync/v1/syncv1grpc" v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/sync/v1" @@ -18,9 +23,10 @@ import ( ) const ( - // Prefix for GRPC URL inputs. GRPC does not define a prefix through standard. This prefix helps to differentiate - // remote URLs for REST APIs (i.e - HTTP) from GRPC endpoints. - Prefix = "grpc://" + // Prefix for GRPC URL inputs. GRPC does not define a standard prefix. This prefix helps to differentiate remote + // URLs for REST APIs (i.e - HTTP) from GRPC endpoints. + Prefix = "grpc://" + PrefixSecure = "grpcs://" // Connection retry constants // Back off period is calculated with backOffBase ^ #retry-iteration. However, when #retry-iteration count reach @@ -28,37 +34,44 @@ const ( backOffLimit = 3 backOffBase = 4 constantBackOffDelay = 60 + + tlsVersion = tls.VersionTLS12 ) var once msync.Once type Sync struct { - Target string + URI string ProviderID string + CertPath string Logger *logger.Logger - Mux *msync.RWMutex - syncClient syncv1grpc.FlagSyncService_SyncFlagsClient - client syncv1grpc.FlagSyncServiceClient - options []grpc.DialOption - ready bool + client syncv1grpc.FlagSyncServiceClient + ready bool } -func (g *Sync) connectClient(ctx context.Context) error { - // initial dial and connection. Failure here must result in a startup failure - dial, err := grpc.DialContext(ctx, g.Target, g.options...) +func (g *Sync) Init(ctx context.Context) error { + tCredentials, err := buildTransportCredentials(g.URI, g.CertPath) if err != nil { + g.Logger.Error(fmt.Sprintf("error building transport credentials: %s", err.Error())) return err } - g.client = syncv1grpc.NewFlagSyncServiceClient(dial) + target, ok := sourceToGRPCTarget(g.URI) + if !ok { + return fmt.Errorf("invalid grpc source: %s", g.URI) + } - syncClient, err := g.client.SyncFlags(ctx, &v1.SyncFlagsRequest{ProviderId: g.ProviderID}) + // Derive reusable client connection + rpcCon, err := grpc.DialContext(ctx, target, grpc.WithTransportCredentials(tCredentials)) if err != nil { - g.Logger.Error(fmt.Sprintf("error calling streaming operation: %s", err.Error())) + g.Logger.Error(fmt.Sprintf("error initiating grpc client connection: %s", err.Error())) return err } - g.syncClient = syncClient + + // Setup service client + g.client = syncv1grpc.NewFlagSyncServiceClient(rpcCon) + return nil } @@ -70,30 +83,28 @@ func (g *Sync) ReSync(ctx context.Context, dataSync chan<- sync.DataSync) error } dataSync <- sync.DataSync{ FlagData: res.GetFlagConfiguration(), - Source: g.Target, + Source: g.URI, Type: sync.ALL, } return nil } -func (g *Sync) Init(ctx context.Context) error { - g.options = []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), - } - - // initial dial and connection. Failure here must result in a startup failure - return g.connectClient(ctx) -} - func (g *Sync) IsReady() bool { return g.ready } func (g *Sync) Sync(ctx context.Context, dataSync chan<- sync.DataSync) error { - // initial stream listening - err := g.handleFlagSync(g.syncClient, dataSync) + // Initialize SyncFlags client. This fails if server connection establishment fails (ex:- grpc server offline) + syncClient, err := g.client.SyncFlags(ctx, &v1.SyncFlagsRequest{ProviderId: g.ProviderID}) + if err != nil { + return err + } + + // Initial stream listening. Error will be logged and continue and retry connection establishment + err = g.handleFlagSync(syncClient, dataSync) if err == nil { - return nil + // This should not happen as handleFlagSync expects to return with an error + return err } g.Logger.Warn(fmt.Sprintf("error with stream listener: %s", err.Error())) @@ -141,12 +152,7 @@ func (g *Sync) connectWithRetry( return nil, false } - g.Logger.Warn(fmt.Sprintf("connection re-establishment attempt in-progress for grpc target: %s", g.Target)) - - if err := g.connectClient(ctx); err != nil { - g.Logger.Debug(fmt.Sprintf("error dialing target: %s", err.Error())) - continue - } + g.Logger.Warn(fmt.Sprintf("connection re-establishment attempt in-progress for grpc target: %s", g.URI)) syncClient, err := g.client.SyncFlags(ctx, &v1.SyncFlagsRequest{ProviderId: g.ProviderID}) if err != nil { @@ -154,7 +160,7 @@ func (g *Sync) connectWithRetry( continue } - g.Logger.Info(fmt.Sprintf("connection re-established with grpc target: %s", g.Target)) + g.Logger.Info(fmt.Sprintf("connection re-established with grpc target: %s", g.URI)) return syncClient, true } } @@ -176,7 +182,7 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, case v1.SyncState_SYNC_STATE_ALL: dataSync <- sync.DataSync{ FlagData: data.FlagConfiguration, - Source: g.Target, + Source: g.URI, Type: sync.ALL, } @@ -184,7 +190,7 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, case v1.SyncState_SYNC_STATE_ADD: dataSync <- sync.DataSync{ FlagData: data.FlagConfiguration, - Source: g.Target, + Source: g.URI, Type: sync.ADD, } @@ -192,7 +198,7 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, case v1.SyncState_SYNC_STATE_UPDATE: dataSync <- sync.DataSync{ FlagData: data.FlagConfiguration, - Source: g.Target, + Source: g.URI, Type: sync.UPDATE, } @@ -200,7 +206,7 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, case v1.SyncState_SYNC_STATE_DELETE: dataSync <- sync.DataSync{ FlagData: data.FlagConfiguration, - Source: g.Target, + Source: g.URI, Type: sync.DELETE, } @@ -213,14 +219,57 @@ func (g *Sync) handleFlagSync(stream syncv1grpc.FlagSyncService_SyncFlagsClient, } } -// URLToGRPCTarget is a helper to derive GRPC target from a provided URL +// buildTransportCredentials is a helper to build grpc credentials.TransportCredentials based on source and cert path +func buildTransportCredentials(source string, certPath string) (credentials.TransportCredentials, error) { + if strings.Contains(source, Prefix) { + return insecure.NewCredentials(), nil + } + + if !strings.Contains(source, PrefixSecure) { + return nil, fmt.Errorf("invalid source. grpc source must contain prefix %s or %s", Prefix, PrefixSecure) + } + + if certPath == "" { + // Rely on CA certs provided from system + return credentials.NewTLS(&tls.Config{MinVersion: tlsVersion}), nil + } + + // Rely on provided certificate + certBytes, err := os.ReadFile(certPath) + if err != nil { + return nil, err + } + + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(certBytes) { + return nil, fmt.Errorf("invalid certificate provided at path: %s", certPath) + } + + return credentials.NewTLS(&tls.Config{ + MinVersion: tlsVersion, + RootCAs: cp, + }), nil +} + +// sourceToGRPCTarget is a helper to derive GRPC target from a provided URL // For example, function returns the target localhost:9090 for the input grpc://localhost:9090 -func URLToGRPCTarget(url string) string { - index := strings.Split(url, Prefix) +func sourceToGRPCTarget(url string) (string, bool) { + var separator string + + switch { + case strings.Contains(url, Prefix): + separator = Prefix + case strings.Contains(url, PrefixSecure): + separator = PrefixSecure + default: + return "", false + } + + index := strings.Split(url, separator) - if len(index) == 2 { - return index[1] + if len(index) == 2 && len(index[1]) != 0 { + return index[1], true } - return index[0] + return "", false } diff --git a/pkg/sync/grpc/grpc_sync_test.go b/pkg/sync/grpc/grpc_sync_test.go index 735031546..0a61488a4 100644 --- a/pkg/sync/grpc/grpc_sync_test.go +++ b/pkg/sync/grpc/grpc_sync_test.go @@ -7,7 +7,11 @@ import ( "io" "log" "net" + "os" "testing" + "time" + + "golang.org/x/sync/errgroup" "buf.build/gen/go/open-feature/flagd/grpc/go/sync/v1/syncv1grpc" v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/sync/v1" @@ -19,6 +23,34 @@ import ( "google.golang.org/grpc/test/bufconn" ) +const sampleCert = `-----BEGIN CERTIFICATE----- +MIIEnDCCAoQCCQCHcl3hGXwRQzANBgkqhkiG9w0BAQsFADAQMQ4wDAYDVQQDDAVm +bGFnZDAeFw0yMzAyMTAxODM1NDVaFw0zMzAyMDcxODM1NDVaMBAxDjAMBgNVBAMM +BWZsYWdkMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAwDLEAUti/kG9 +MhJLtO7oAy7diHxWKDFmsIHrE+z2IzTxjXxVHQLv1HiYB/UN75y7qlb3MwvzSc+C +BoLuoiM0PDiMio9/o9X5j0U+v3H1JpUU5LardkvsprFqJWmHF+D7aRdM0LBLn2X6 +HQOhSnPyH9Qjl2l2tyPiPTZ6g0i2+rXZsNUoTs4fm6ThhZ0LeXR8KDmCTun3ze1d +hXA7ydxwILH2OVc+Wnzl30+BRvOiLQbc9nYnwSREFeIy8sFbhrTHqSNn3eY79ssZ +T6f4tN3jEV1d7NqoFk9KFLJKJhMt7smMB9NLwVWi581Zj1krYirNlP6mtmPrn3kJ +lsgT15kFftShMVcYFSHqOSLiy4SspHGK8KJaFoEVx0wp/weRwrWXi6vWg7tuHATH +fw7gW/9CyV+ylc0pJ002wtPAgzJYUaOrna0R2r3yQsSzRcDnqsm4FLkPHLoyjrwQ +vshKcEqjhGml1M+lTDEo3RO5ZoQ3ZN2AZKPDrK2zGG4wFJjHRu9FtutOEZkYYOzA +emTQWW8US3q8WVQqGl/EwQqzXk9Lco7uhLdXmqVOvAi6z01gehQJPnjhH7iqAPVp +1tlOBHit1F3sTAQIO/2zff3LCKiD2d27KINh4aFEyDbDmglPA8VPO3BMQVSjFlxj +K1s2G1IDBixXK76VmBP+ZpvxOaQtYIUCAwEAATANBgkqhkiG9w0BAQsFAAOCAgEA +K9+wnl5gpkfNBa+OSxlhOn3CKhcaW/SWZ4aLw2yK1NZNnNjpwUcLQScUDBKDoJJR +5roc3PIImX7hdnobZWqFhD23laaAlu5XLk9P7n51uMEiNjQQc2WaaBZDTRJfki1C +MvPskXqptgPsVyuPJc0DxfaCz7pDYjq/CtJ+osaj404P5mlO1QJ8W91QSx+aq2x4 +uUTUWuyr/8flIcxiX0o8VTb2LcUvWZBMGa3CdeLnPHrOjovfjJFy0Ysk3SGEACLL +9mpbNbv23v9UXVfyFffHpyzvyUJIOsNXG0O1AYf5t9bukqHolGR/RQUN4yGd3M62 +mFR5bOST36DjNSzTrx1eyCLv22+h9VVlWFPrebFnq1W5SSi8PtsGSMjhvX7dB1kS +t0yJtlj2HwBAvI1zVKG76q6neSU51UXFQUbO0OA0sxjicEOlNfXnShM/kY2lobpX +hrCysWpqoSS0S3UBvmuRiraLWkP1KueC0XHoAi8yuwMAdM6Y+h2OJpnO0PdpUmrp +lAqdxbyICnB1Nsm5QGGm6Pxd8lEbQ9ZSwFjgqApjT2zVhuaaUC7jdlEP1H5snt9n +8FQR06lrzGyW04ud9pd6MXJup1oghAlvnzXioAH2Az0IXcHvqUGZQattFv27OXqj +QZ6ayNO119SNscvC6Qe9GLlbBEHDQWKPiftnS2Mh6Do= +-----END CERTIFICATE-----` + func Test_ReSyncTests(t *testing.T) { const target = "localBufCon" @@ -76,7 +108,7 @@ func Test_ReSyncTests(t *testing.T) { c := syncv1grpc.NewFlagSyncServiceClient(dial) grpcSync := Sync{ - Target: target, + URI: target, ProviderID: "", Logger: logger.NewLogger(nil, false), client: c, @@ -110,32 +142,60 @@ func Test_ReSyncTests(t *testing.T) { } } -func TestUrlToGRPCTarget(t *testing.T) { +func TestSourceToGRPCTarget(t *testing.T) { tests := []struct { name string url string want string + ok bool }{ { name: "With Prefix", url: "grpc://test.com/endpoint", want: "test.com/endpoint", + ok: true, }, { - name: "Without Prefix", - url: "test.com/endpoint", + name: "With secure Prefix", + url: "grpcs://test.com/endpoint", want: "test.com/endpoint", + ok: true, }, { - name: "Empty is empty", + name: "Empty is error", url: "", want: "", + ok: false, + }, + { + name: "Invalid is error", + url: "https://test.com/endpoint", + want: "", + ok: false, + }, + { + name: "Prefix is not enough I", + url: Prefix, + want: "", + ok: false, + }, + { + name: "Prefix is not enough II", + url: PrefixSecure, + want: "", + ok: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := URLToGRPCTarget(tt.url); got != tt.want { - t.Errorf("URLToGRPCTarget() = %v, want %v", got, tt.want) + got, ok := sourceToGRPCTarget(tt.url) + + if tt.ok != ok { + t.Errorf("URLToGRPCTarget() returned = %v, want %v", ok, tt.ok) + } + + if got != tt.want { + t.Errorf("URLToGRPCTarget() returned = %v, want %v", got, tt.want) } }) } @@ -143,23 +203,25 @@ func TestUrlToGRPCTarget(t *testing.T) { func TestSync_BasicFlagSyncStates(t *testing.T) { grpcSyncImpl := Sync{ - Target: "grpc://test", + URI: "grpc://test", ProviderID: "", Logger: logger.NewLogger(nil, false), } tests := []struct { name string - stream syncv1grpc.FlagSyncService_SyncFlagsClient + stream syncv1grpc.FlagSyncServiceClient want sync.Type ready bool }{ { name: "State All maps to Sync All", - stream: &SimpleRecvMock{ - mockResponse: v1.SyncFlagsResponse{ - FlagConfiguration: "{}", - State: v1.SyncState_SYNC_STATE_ALL, + stream: &MockServiceClient{ + mockStream: SimpleRecvMock{ + mockResponse: v1.SyncFlagsResponse{ + FlagConfiguration: "{}", + State: v1.SyncState_SYNC_STATE_ALL, + }, }, }, want: sync.ALL, @@ -167,10 +229,12 @@ func TestSync_BasicFlagSyncStates(t *testing.T) { }, { name: "State Add maps to Sync Add", - stream: &SimpleRecvMock{ - mockResponse: v1.SyncFlagsResponse{ - FlagConfiguration: "{}", - State: v1.SyncState_SYNC_STATE_ADD, + stream: &MockServiceClient{ + mockStream: SimpleRecvMock{ + mockResponse: v1.SyncFlagsResponse{ + FlagConfiguration: "{}", + State: v1.SyncState_SYNC_STATE_ADD, + }, }, }, want: sync.ADD, @@ -178,10 +242,12 @@ func TestSync_BasicFlagSyncStates(t *testing.T) { }, { name: "State Update maps to Sync Update", - stream: &SimpleRecvMock{ - mockResponse: v1.SyncFlagsResponse{ - FlagConfiguration: "{}", - State: v1.SyncState_SYNC_STATE_UPDATE, + stream: &MockServiceClient{ + mockStream: SimpleRecvMock{ + mockResponse: v1.SyncFlagsResponse{ + FlagConfiguration: "{}", + State: v1.SyncState_SYNC_STATE_UPDATE, + }, }, }, want: sync.UPDATE, @@ -189,10 +255,12 @@ func TestSync_BasicFlagSyncStates(t *testing.T) { }, { name: "State Delete maps to Sync Delete", - stream: &SimpleRecvMock{ - mockResponse: v1.SyncFlagsResponse{ - FlagConfiguration: "{}", - State: v1.SyncState_SYNC_STATE_DELETE, + stream: &MockServiceClient{ + mockStream: SimpleRecvMock{ + mockResponse: v1.SyncFlagsResponse{ + FlagConfiguration: "{}", + State: v1.SyncState_SYNC_STATE_DELETE, + }, }, }, want: sync.DELETE, @@ -205,12 +273,13 @@ func TestSync_BasicFlagSyncStates(t *testing.T) { syncChan := make(chan sync.DataSync) go func() { - grpcSyncImpl.syncClient = test.stream + grpcSyncImpl.client = test.stream err := grpcSyncImpl.Sync(context.TODO(), syncChan) if err != nil { t.Errorf("Error handling flag sync: %s", err.Error()) } }() + data := <-syncChan if grpcSyncImpl.IsReady() != test.ready { @@ -329,12 +398,6 @@ func Test_StreamListener(t *testing.T) { // start server go serve(&bufServer) - grpcSync := Sync{ - Target: target, - ProviderID: "", - Logger: logger.NewLogger(nil, false), - } - // initialize client dial, err := grpc.Dial(target, grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { @@ -346,16 +409,19 @@ func Test_StreamListener(t *testing.T) { } serviceClient := syncv1grpc.NewFlagSyncServiceClient(dial) - syncClient, err := serviceClient.SyncFlags(context.Background(), &v1.SyncFlagsRequest{ProviderId: grpcSync.ProviderID}) - if err != nil { - t.Errorf("Error opening client stream: %s", err.Error()) + + grpcSync := Sync{ + URI: target, + ProviderID: "", + Logger: logger.NewLogger(nil, false), + + client: serviceClient, } syncChan := make(chan sync.DataSync, 1) // listen to stream go func() { - grpcSync.syncClient = syncClient err := grpcSync.Sync(context.TODO(), syncChan) if err != nil { // must ignore EOF as this is returned for stream end @@ -384,8 +450,275 @@ func Test_StreamListener(t *testing.T) { } } +func Test_BuildTCredentials(t *testing.T) { + // "insecure" is a hardcoded term at insecure.NewCredentials + const insecure = "insecure" + // "tls" is a hardcoded term at tlsCreds.Info + const tls = "tls" + // local test file with valid certificate + const validCertFile = "valid.cert" + // local test file with invalid certificate + const invalidCertFile = "invalid.cert" + + // init cert files for tests & cleanup with a deffer + err := os.WriteFile(validCertFile, []byte(sampleCert), 0o600) + if err != nil { + t.Errorf("error creating valid certificate file: %s", err) + } + + err = os.WriteFile(invalidCertFile, []byte("--certificate--"), 0o600) + if err != nil { + t.Errorf("error creating invalid certificate file: %s", err) + } + + defer func() { + errV := os.Remove(validCertFile) + errI := os.Remove(invalidCertFile) + if errV != nil || errI != nil { + t.Errorf("error removing cerificate files: %v, %v", errV, errI) + } + }() + + tests := []struct { + name string + source string + certPath string + expectSecProto string + error bool + }{ + { + name: "Insecure source results in insecure connection", + source: Prefix + "some.domain", + certPath: "", + expectSecProto: insecure, + }, + { + name: "Secure source results in secure connection", + source: PrefixSecure + "some.domain", + certPath: validCertFile, + expectSecProto: tls, + }, + { + name: "Secure source with no certificate results in a secure connection", + source: PrefixSecure + "some.domain", + expectSecProto: tls, + }, + { + name: "Invalid cert path results in an error", + source: PrefixSecure + "some.domain", + certPath: "invalid/path", + error: true, + }, + { + name: "Invalid certificate results in an error", + source: PrefixSecure + "some.domain", + certPath: invalidCertFile, + error: true, + }, + { + name: "Invalid prefix results in an error", + source: "http://some.domain", + error: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tCred, err := buildTransportCredentials(test.source, test.certPath) + + if test.error { + if err == nil { + t.Errorf("test expected non error execution. But resulted in an error: %s", err.Error()) + } + + // Test expected an error. Nothing to validate further + return + } + + // check for errors to be certain + if err != nil { + t.Errorf("unexpected error: %s", err.Error()) + } + + protoc := tCred.Info().SecurityProtocol + if protoc != test.expectSecProto { + t.Errorf("buildTransportCredentials() returned protocol= %v, want %v", protoc, test.expectSecProto) + } + }) + } +} + +// Test_ConnectWithRetry is an attempt to validate grpc.connectWithRetry behavior +func Test_ConnectWithRetry(t *testing.T) { + target := "grpc://local" + bufListener := bufconn.Listen(1) + // buffer based server. response ignored purposefully + bServer := bufferedServer{listener: bufListener} + + // generate a client connection backed with bufconn + clientConn, err := grpc.Dial(target, + grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { + return bufListener.DialContext(ctx) + }), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Errorf("error initiating the connection: %s", err.Error()) + } + + // minimal sync provider + grpcSync := Sync{ + Logger: logger.NewLogger(nil, false), + client: syncv1grpc.NewFlagSyncServiceClient(clientConn), + } + + // test must complete within an acceptable timeframe + tCtx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFunc() + + // channel for connection + clientChan := make(chan syncv1grpc.FlagSyncService_SyncFlagsClient) + + // start connection retry attempts + go func() { + client, ok := grpcSync.connectWithRetry(tCtx) + if !ok { + clientChan <- nil + } + + clientChan <- client + }() + + // Wait for retries in the background + select { + case <-time.After(2 * time.Second): + break + case <-tCtx.Done(): + // We should not reach this with correct test setup, but in case we do + cancelFunc() + t.Errorf("timeout occurred while waiting for conditions to fulfil") + } + + // start the server - fulfill connection after the wait + go serve(&bServer) + + // Wait for client connection + var con syncv1grpc.FlagSyncService_SyncFlagsClient + + select { + case con = <-clientChan: + break + case <-tCtx.Done(): + cancelFunc() + t.Errorf("timeout occurred while waiting for conditions to fulfil") + } + + if con == nil { + t.Errorf("received a nil value when expecting a non-nil return") + } +} + +// Test_SyncRetry validates sync and retry attempts +func Test_SyncRetry(t *testing.T) { + // Setup + target := "grpc://local" + bufListener := bufconn.Listen(1) + + expectType := sync.ALL + + // buffer based server. response ignored purposefully + bServer := bufferedServer{listener: bufListener, mockResponses: []serverPayload{ + { + flags: "{}", + state: v1.SyncState_SYNC_STATE_ALL, + }, + }} + + // generate a client connection backed by bufListener + clientConn, err := grpc.Dial(target, + grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { + return bufListener.DialContext(ctx) + }), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Errorf("error initiating the connection: %s", err.Error()) + } + + // minimal sync provider + grpcSync := Sync{ + Logger: logger.NewLogger(nil, false), + client: syncv1grpc.NewFlagSyncServiceClient(clientConn), + } + + // channel for data sync + syncChan := make(chan sync.DataSync, 1) + + // Testing + + // Initial mock server - start mock server backed by a error group. Allow connection and disconnect with a timeout + tCtx, cancelFunc := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFunc() + + group, _ := errgroup.WithContext(tCtx) + group.Go(func() error { + serve(&bServer) + return nil + }) + + // Start Sync for grpc streaming + go func() { + err := grpcSync.Sync(context.Background(), syncChan) + if err != nil { + t.Errorf("sync start error: %s", err.Error()) + } + }() + + // Check for timeout (not ideal) or data sync (ideal) and cancel the context + select { + case <-tCtx.Done(): + t.Errorf("timeout waiting for conditions to fulfil") + break + case data := <-syncChan: + if data.Type != expectType { + t.Errorf("sync start error: %s", err.Error()) + } + } + + // cancel make error group to complete, making background mock server to exit + cancelFunc() + + // Follow up mock server start - start mock server after initial shutdown + tCtx, cancelFunc = context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFunc() + + // Restart the server + go serve(&bServer) + + // validate connection re-establishment + select { + case <-tCtx.Done(): + cancelFunc() + t.Error("timeout waiting for conditions to fulfil") + case rsp := <-syncChan: + if rsp.Type != expectType { + t.Errorf("expected response: %s, but got: %s", expectType, rsp.Type) + } + } +} + // Mock implementations +type MockServiceClient struct { + syncv1grpc.FlagSyncServiceClient + + mockStream SimpleRecvMock +} + +func (c *MockServiceClient) SyncFlags(_ context.Context, + _ *v1.SyncFlagsRequest, _ ...grpc.CallOption, +) (syncv1grpc.FlagSyncService_SyncFlagsClient, error) { + return &c.mockStream, nil +} + type SimpleRecvMock struct { grpc.ClientStream mockResponse v1.SyncFlagsResponse @@ -395,7 +728,7 @@ func (s *SimpleRecvMock) Recv() (*v1.SyncFlagsResponse, error) { return &s.mockResponse, nil } -// serve serves a bufferedServer +// serve serves a bufferedServer. This is a blocking call func serve(bServer *bufferedServer) { server := grpc.NewServer() diff --git a/pkg/sync/isync.go b/pkg/sync/isync.go index 82ca6ee27..115011da2 100644 --- a/pkg/sync/isync.go +++ b/pkg/sync/isync.go @@ -63,4 +63,5 @@ type SourceConfig struct { Provider string `json:"provider"` BearerToken string `json:"bearerToken,omitempty"` + CertPath string `json:"certPath,omitempty"` }