diff --git a/share/getters/cascade.go b/share/getters/cascade.go index d65b902e81..1a0d8fb274 100644 --- a/share/getters/cascade.go +++ b/share/getters/cascade.go @@ -122,16 +122,18 @@ func cascadeGetters[V any]( getCtx, cancel := ctxWithSplitTimeout(ctx, len(getters)-i, 0) val, getErr := get(getCtx, getter) cancel() - if getErr == nil { - return val, nil + if getErr == nil || errors.Is(getErr, share.ErrNamespaceNotFound) { + return val, getErr } - if errors.Is(share.ErrNamespaceNotFound, getErr) { - return zero, getErr + + if errors.Is(getErr, errOperationNotSupported) { + continue } - if !errors.Is(getErr, errOperationNotSupported) { - err = errors.Join(err, getErr) - span.RecordError(getErr, trace.WithAttributes(attribute.Int("getter_idx", i))) + err = errors.Join(err, getErr) + span.RecordError(getErr, trace.WithAttributes(attribute.Int("getter_idx", i))) + if ctx.Err() != nil { + return zero, err } } return zero, err diff --git a/share/getters/cascade_test.go b/share/getters/cascade_test.go index e3a324c5e9..d955c50682 100644 --- a/share/getters/cascade_test.go +++ b/share/getters/cascade_test.go @@ -52,6 +52,7 @@ func TestCascade(t *testing.T) { timeoutGetter := mocks.NewMockGetter(ctrl) immediateFailGetter := mocks.NewMockGetter(ctrl) successGetter := mocks.NewMockGetter(ctrl) + ctxGetter := mocks.NewMockGetter(ctrl) timeoutGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, _ *share.Root) (*rsmt2d.ExtendedDataSquare, error) { return nil, context.DeadlineExceeded @@ -60,6 +61,10 @@ func TestCascade(t *testing.T) { Return(nil, errors.New("second getter fails immediately")).AnyTimes() successGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()). Return(nil, nil).AnyTimes() + ctxGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, _ *share.Root) (*rsmt2d.ExtendedDataSquare, error) { + return nil, ctx.Err() + }).AnyTimes() get := func(ctx context.Context, get share.Getter) (*rsmt2d.ExtendedDataSquare, error) { return get.GetEDS(ctx, nil) @@ -96,6 +101,15 @@ func TestCascade(t *testing.T) { assert.Equal(t, strings.Count(err.Error(), "\n"), 2) }) + t.Run("Context Canceled", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + cancel() + getters := []share.Getter{ctxGetter, ctxGetter, ctxGetter} + _, err := cascadeGetters(ctx, getters, get) + assert.Error(t, err) + assert.Equal(t, strings.Count(err.Error(), "\n"), 0) + }) + t.Run("Single", func(t *testing.T) { getters := []share.Getter{successGetter} _, err := cascadeGetters(ctx, getters, get)