diff --git a/pkg/activator/handler/handler.go b/pkg/activator/handler/handler.go index 503754f3b2dc..83fbece9c24b 100644 --- a/pkg/activator/handler/handler.go +++ b/pkg/activator/handler/handler.go @@ -90,7 +90,7 @@ func New(l *zap.SugaredLogger, r activator.StatsReporter, t *activator.Throttler } } -func withOrigProto(or *http.Request) prober.ProbeOption { +func withOrigProto(or *http.Request) prober.Preparer { return func(r *http.Request) *http.Request { r.Proto = or.Proto r.ProtoMajor = or.ProtoMajor @@ -113,7 +113,13 @@ func (a *activationHandler) probeEndpoint(logger *zap.SugaredLogger, r *http.Req err := wait.PollImmediate(100*time.Millisecond, a.probeTimeout, func() (bool, error) { attempts++ - ret, err := prober.Do(reqCtx, a.probeTransportFactory(), target.String(), queue.Name, withOrigProto(r)) + ret, err := prober.Do( + reqCtx, + a.probeTransportFactory(), + target.String(), + prober.WithHeader(network.ProbeHeaderName, queue.Name), + prober.ExpectsBody(queue.Name), + withOrigProto(r)) if err != nil { logger.Warnw("Pod probe failed", zap.Error(err)) return false, nil diff --git a/pkg/network/prober/prober.go b/pkg/network/prober/prober.go index 7e1755c8a50b..770d90259b15 100644 --- a/pkg/network/prober/prober.go +++ b/pkg/network/prober/prober.go @@ -26,29 +26,45 @@ import ( "github.com/pkg/errors" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" - - "github.com/knative/serving/pkg/network" ) // TransportFactory is a function which returns an HTTP transport. type TransportFactory func() http.RoundTripper -// ProbeOption is a way for caller to modify the HTTP request before it goes out. -type ProbeOption func(r *http.Request) *http.Request +// Preparer is a way for the caller to modify the HTTP request before it goes out. +type Preparer func(r *http.Request) *http.Request + +// Verifier is a way for the caller to validate the HTTP response after it comes back. +type Verifier func(r *http.Response, b []byte) (bool, error) + +// WithHeader sets a header in the probe request. +func WithHeader(name, value string) Preparer { + return func(r *http.Request) *http.Request { + r.Header.Set(name, value) + return r + } +} + +// ExpectsBody validates that the body of the probe response matches the provided string. +func ExpectsBody(body string) Verifier { + return func(r *http.Response, b []byte) (bool, error) { + return string(b) == body, nil + } +} // Do sends a single probe to given target, e.g. `http://revision.default.svc.cluster.local:81`. -// headerValue is the value for the `k-network-probe` header. // Do returns whether the probe was successful or not, or there was an error probing. -func Do(ctx context.Context, transport http.RoundTripper, target, headerValue string, pos ...ProbeOption) (bool, error) { +func Do(ctx context.Context, transport http.RoundTripper, target string, ops ...interface{}) (bool, error) { req, err := http.NewRequest(http.MethodGet, target, nil) if err != nil { return false, errors.Wrapf(err, "%s is not a valid URL", target) } - for _, po := range pos { - req = po(req) + for _, op := range ops { + if po, ok := op.(Preparer); ok { + req = po(req) + } } - req.Header.Set(network.ProbeHeaderName, headerValue) req = req.WithContext(ctx) resp, err := transport.RoundTrip(req) if err != nil { @@ -59,7 +75,16 @@ func Do(ctx context.Context, transport http.RoundTripper, target, headerValue st if err != nil { return false, errors.Wrap(err, "error reading body") } - return resp.StatusCode == http.StatusOK && string(body) == headerValue, nil + + for _, op := range ops { + if vo, ok := op.(Verifier); ok { + ok, err := vo(resp, body) + if err != nil || !ok { + return false, err + } + } + } + return resp.StatusCode == http.StatusOK, nil } // Done is a callback that is executed when the async probe has finished. @@ -101,19 +126,19 @@ func New(cb Done, transportFactory TransportFactory) *Manager { // Otherwise Offer starts a goroutine that periodically executes // `Do`, until timeout is reached, the probe succeeds, or fails with an error. // In the end the callback is invoked with the provided `arg` and probing results. -func (m *Manager) Offer(ctx context.Context, target, headerValue string, arg interface{}, period, timeout time.Duration) bool { +func (m *Manager) Offer(ctx context.Context, target string, arg interface{}, period, timeout time.Duration, ops ...interface{}) bool { m.mu.Lock() defer m.mu.Unlock() if m.keys.Has(target) { return false } m.keys.Insert(target) - m.doAsync(ctx, m.transportFactory, target, headerValue, arg, period, timeout) + m.doAsync(ctx, m.transportFactory, target, arg, period, timeout, ops...) return true } // doAsync starts a go routine that probes the target with given period. -func (m *Manager) doAsync(ctx context.Context, transportFactory TransportFactory, target, headerValue string, arg interface{}, period, timeout time.Duration) { +func (m *Manager) doAsync(ctx context.Context, transportFactory TransportFactory, target string, arg interface{}, period, timeout time.Duration, ops ...interface{}) { go func() { defer func() { m.mu.Lock() @@ -124,8 +149,9 @@ func (m *Manager) doAsync(ctx context.Context, transportFactory TransportFactory result bool err error ) + err = wait.PollImmediate(period, timeout, func() (bool, error) { - result, err = Do(ctx, transportFactory(), target, headerValue) + result, err = Do(ctx, transportFactory(), target, ops...) return result, err }) m.cb(arg, result, err) diff --git a/pkg/network/prober/prober_test.go b/pkg/network/prober/prober_test.go index 2b26eab92383..fc484ec5c909 100644 --- a/pkg/network/prober/prober_test.go +++ b/pkg/network/prober/prober_test.go @@ -70,7 +70,7 @@ func TestDoServing(t *testing.T) { }} for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := Do(context.Background(), network.NewAutoTransport(), ts.URL, test.headerValue) + got, err := Do(context.Background(), network.NewAutoTransport(), ts.URL, WithHeader(network.ProbeHeaderName, test.headerValue), ExpectsBody(test.headerValue)) if want := test.want; got != want { t.Errorf("Got = %v, want: %v", got, want) } @@ -82,7 +82,7 @@ func TestDoServing(t *testing.T) { } func TestBlackHole(t *testing.T) { - got, err := Do(context.Background(), network.NewAutoTransport(), "http://gone.fishing.svc.custer.local:8080", systemName) + got, err := Do(context.Background(), network.NewAutoTransport(), "http://gone.fishing.svc.custer.local:8080") if want := false; got != want { t.Errorf("Got = %v, want: %v", got, want) } @@ -92,7 +92,7 @@ func TestBlackHole(t *testing.T) { } func TestBadURL(t *testing.T) { - _, err := Do(context.Background(), network.NewAutoTransport(), ":foo", systemName) + _, err := Do(context.Background(), network.NewAutoTransport(), ":foo") if err == nil { t.Error("Do did not return an error") } @@ -150,7 +150,7 @@ func TestDoAsync(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { m := New(test.cb, network.NewAutoTransport) - m.Offer(context.Background(), ts.URL, test.headerValue, test.name, 50*time.Millisecond, 2*time.Second) + m.Offer(context.Background(), ts.URL, test.name, 50*time.Millisecond, 2*time.Second, WithHeader(network.ProbeHeaderName, test.headerValue), ExpectsBody(test.headerValue)) <-wch }) } @@ -187,7 +187,7 @@ func TestDoAsyncRepeat(t *testing.T) { wch <- arg } m := New(cb, network.NewAutoTransport) - m.Offer(context.Background(), ts.URL, systemName, 42, 50*time.Millisecond, 3*time.Second) + m.Offer(context.Background(), ts.URL, 42, 50*time.Millisecond, 3*time.Second, WithHeader(network.ProbeHeaderName, systemName), ExpectsBody(systemName)) <-wch if got, want := c.calls, 3; got != want { t.Errorf("Probe invocation count = %d, want: %d", got, want) @@ -210,7 +210,7 @@ func TestDoAsyncTimeout(t *testing.T) { wch <- arg } m := New(cb, network.NewAutoTransport) - m.Offer(context.Background(), ts.URL, systemName, 2009, 10*time.Millisecond, 200*time.Millisecond) + m.Offer(context.Background(), ts.URL, 2009, 10*time.Millisecond, 200*time.Millisecond) <-wch } @@ -225,10 +225,10 @@ func TestAsyncMultiple(t *testing.T) { wch <- 2006 } m := New(cb, network.NewAutoTransport) - if !m.Offer(context.Background(), ts.URL, systemName, 1984, 100*time.Millisecond, 1*time.Second) { + if !m.Offer(context.Background(), ts.URL, 1984, 100*time.Millisecond, 1*time.Second) { t.Error("First call to offer returned false") } - if m.Offer(context.Background(), ts.URL, systemName, 1982, 100*time.Millisecond, 1*time.Second) { + if m.Offer(context.Background(), ts.URL, 1982, 100*time.Millisecond, 1*time.Second) { t.Error("Second call to offer returned true") } if got, want := m.len(), 1; got != want { diff --git a/pkg/reconciler/autoscaling/kpa/scaler.go b/pkg/reconciler/autoscaling/kpa/scaler.go index e5b0491402c1..7f1b2ad47e50 100644 --- a/pkg/reconciler/autoscaling/kpa/scaler.go +++ b/pkg/reconciler/autoscaling/kpa/scaler.go @@ -53,9 +53,14 @@ const ( reenqeuePeriod = 1 * time.Second ) +var probeOptions = []interface{} { + prober.WithHeader(network.ProbeHeaderName, activator.Name), + prober.ExpectsBody(activator.Name), +} + // for mocking in tests type asyncProber interface { - Offer(context.Context, string, string, interface{}, time.Duration, time.Duration) bool + Offer(context.Context, string, interface{}, time.Duration, time.Duration, ...interface{}) bool } // scaler scales the target of a kpa-class PA up or down including scaling to zero. @@ -112,7 +117,7 @@ func activatorProbe(pa *pav1alpha1.PodAutoscaler, transport http.RoundTripper) ( if pa.Status.ServiceName == "" { return false, nil } - return prober.Do(context.Background(), transport, paToProbeTarget(pa), activator.Name) + return prober.Do(context.Background(), transport, paToProbeTarget(pa), probeOptions...) } // pre: 0 <= min <= max && 0 <= x @@ -172,7 +177,7 @@ func (ks *scaler) handleScaleToZero(pa *pav1alpha1.PodAutoscaler, desiredScale i // Otherwise (any prober failure) start the async probe. ks.logger.Infof("%s is not yet backed by activator, cannot scale to zero", pa.Name) - if !ks.probeManager.Offer(context.Background(), paToProbeTarget(pa), activator.Name, pa, probePeriod, probeTimeout) { + if !ks.probeManager.Offer(context.Background(), paToProbeTarget(pa), pa, probePeriod, probeTimeout, probeOptions...) { ks.logger.Infof("Probe for %s is already in flight", pa.Name) } return desiredScale, false diff --git a/pkg/reconciler/autoscaling/kpa/scaler_test.go b/pkg/reconciler/autoscaling/kpa/scaler_test.go index 76cede044985..f65baba14ef8 100644 --- a/pkg/reconciler/autoscaling/kpa/scaler_test.go +++ b/pkg/reconciler/autoscaling/kpa/scaler_test.go @@ -583,7 +583,7 @@ type countingProber struct { count int } -func (c *countingProber) Offer(ctx context.Context, target, headerValue string, arg interface{}, period, timeout time.Duration) bool { +func (c *countingProber) Offer(ctx context.Context, target string, arg interface{}, period, timeout time.Duration, ops ...interface{}) bool { c.count++ return true }