diff --git a/src/main/java/no/digipost/DiggBase.java b/src/main/java/no/digipost/DiggBase.java index fd48e69..17cbb33 100644 --- a/src/main/java/no/digipost/DiggBase.java +++ b/src/main/java/no/digipost/DiggBase.java @@ -21,12 +21,13 @@ import java.util.ArrayDeque; import java.util.Deque; -import java.util.Objects; import java.util.Optional; +import java.util.Spliterator; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Stream; +import java.util.stream.StreamSupport; import static java.util.Spliterator.ORDERED; import static java.util.Spliterators.spliterator; @@ -246,15 +247,47 @@ public static Stream forceOnAll(ThrowingConsumer Stream forceOnAll(ThrowingConsumer action, Stream instances) { - return instances.filter(Objects::nonNull).flatMap(instance -> { + return StreamSupport.stream(new FlatMapToExceptionSpliterator<>(action, instances.spliterator()), instances.isParallel()); + } + + private static final class FlatMapToExceptionSpliterator implements Spliterator { + + private final ThrowingConsumer action; + private final Spliterator wrappedSpliterator; + private final int characteristics; + + FlatMapToExceptionSpliterator(ThrowingConsumer action, Spliterator wrappedSpliterator) { + this.action = action; + this.wrappedSpliterator = wrappedSpliterator; + this.characteristics = wrappedSpliterator.characteristics() & ~(SIZED | SUBSIZED | SORTED); + } + + @Override + public boolean tryAdvance(Consumer exceptionConsumer) { try { - action.accept(instance); - } catch (Exception exception) { - return Stream.of(exception); + return wrappedSpliterator.tryAdvance(action.ifException(exceptionConsumer::accept)); + } catch (Exception e) { + exceptionConsumer.accept(e); + return true; } - return Stream.empty(); - }); - } + } + + @Override + public Spliterator trySplit() { + Spliterator triedSplit = wrappedSpliterator.trySplit(); + return triedSplit != null ? new FlatMapToExceptionSpliterator<>(action, triedSplit) : null; + } + + @Override + public long estimateSize() { + return Long.MAX_VALUE; + } + + @Override + public int characteristics() { + return characteristics; + } + }; /** diff --git a/src/test/java/no/digipost/DiggBaseTest.java b/src/test/java/no/digipost/DiggBaseTest.java index 10617e2..4229784 100644 --- a/src/test/java/no/digipost/DiggBaseTest.java +++ b/src/test/java/no/digipost/DiggBaseTest.java @@ -17,33 +17,46 @@ import no.digipost.util.AutoClosed; import no.digipost.util.ThrowingAutoClosed; -import org.hamcrest.Matchers; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.mockito.InOrder; import org.quicktheories.WithQuickTheories; import org.quicktheories.core.Gen; import org.quicktheories.dsl.TheoryBuilder; +import uk.co.probablyfine.matchers.StreamMatchers; import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Stream; import static java.util.Arrays.asList; import static java.util.stream.Collectors.toList; +import static java.util.stream.IntStream.iterate; +import static java.util.stream.IntStream.rangeClosed; import static java.util.stream.Stream.generate; import static no.digipost.DiggBase.autoClose; import static no.digipost.DiggBase.close; +import static no.digipost.DiggBase.forceOnAll; import static no.digipost.DiggBase.friendlyName; import static no.digipost.DiggBase.nonNull; import static no.digipost.DiggBase.throwingAutoClose; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; import static org.hamcrest.Matchers.sameInstance; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; @@ -52,8 +65,6 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; import static uk.co.probablyfine.matchers.Java8Matchers.where; -import static uk.co.probablyfine.matchers.StreamMatchers.contains; -import static uk.co.probablyfine.matchers.StreamMatchers.empty; public class DiggBaseTest implements WithQuickTheories { @@ -89,13 +100,20 @@ public void throwsExceptionWithDescriptionInMessage() { @Test public void extractOptionalValuesFromAnObject() { - assertThat(DiggBase.extractIfPresent("abc", s -> Optional.of(s.charAt(0)), s -> Optional.empty(), s -> Optional.of(s.charAt(2))), contains('a', 'c')); - assertThat(DiggBase.extractIfPresent("abc", s -> Optional.empty(), s -> Optional.empty()), empty()); + assertThat(DiggBase.extractIfPresent("abc", + s -> Optional.of(s.charAt(0)), + s -> Optional.empty(), + s -> Optional.of(s.charAt(2))), + StreamMatchers.contains('a', 'c')); + assertThat(DiggBase.extractIfPresent("abc", + s -> Optional.empty(), + s -> Optional.empty()), + StreamMatchers.empty()); } @Test public void extractValuesIncludesEverythingEvenNulls() { - assertThat(DiggBase.extract("abc", s -> s.charAt(0), s -> null), contains('a', null)); + assertThat(DiggBase.extract("abc", s -> s.charAt(0), s -> null), StreamMatchers.contains('a', null)); } @Test @@ -183,11 +201,84 @@ public void getAllExceptionsFromClosingSeveralAutoCloseables() throws Exception Stream closeExceptionsStream = close(generate(() -> closeable).limit(5).toArray(AutoCloseable[]::new)); verifyNoInteractions(closeable); List closeExceptions = closeExceptionsStream.collect(toList()); - assertThat(closeExceptions, Matchers.contains(asList(instanceOf(IOException.class), instanceOf(IllegalStateException.class)))); + assertThat(closeExceptions, contains(asList(instanceOf(IOException.class), instanceOf(IllegalStateException.class)))); verify(closeable, times(5)).close(); verifyNoMoreInteractions(closeable); } + @Nested + @Timeout(4) + class ForceOnAll { + + @Test + void runsOperationOnMultipleElements() { + List consumed = new ArrayList<>(); + List exceptions = forceOnAll(consumed::add, 1, 2, 3).collect(toList()); + assertThat(consumed, contains(1, 2, 3)); + assertThat(exceptions, empty()); + } + + @Test + void exceptionsFromOperationAreCollected() { + List onlyEvenNumbers = new ArrayList<>(); + List exceptions = forceOnAll(i -> { + if (i % 2 != 0) throw new IllegalArgumentException(i + " is odd!"); + onlyEvenNumbers.add(i); + }, 1, 2, 3, 4).collect(toList()); + assertThat(onlyEvenNumbers, contains(2, 4)); + assertThat(exceptions, contains( + where(Throwable::getMessage, is("1 is odd!")), + where(Throwable::getMessage, is("3 is odd!")))); + } + + @Test + void exceptionsFromTraversingStreamIsCollected() { + List consumed = new ArrayList<>(); + List exceptions = forceOnAll(consumed::add, + iterate(2, num -> num - 1).limit(5).mapToDouble(denominator -> 2 / denominator).boxed()) + .collect(toList()); + + assertAll( + () -> assertThat(exceptions, contains(isA(ArithmeticException.class))), + () -> assertThat(consumed, contains(1.0, 2.0, -2.0, -1.0))); + } + + @Test + void allElementsResolvedFromStreamException() { + List exceptions = forceOnAll(e -> fail("action should never be invoked"), + rangeClosed(1, 10).mapToObj(String::valueOf).map(s -> { throw new IllegalStateException(s); })) + .collect(toList()); + + assertThat(exceptions, hasSize(10)); + } + + @Test + void lastElementsResolvedFromStreamException() { + List consumed = new ArrayList<>(); + List exceptions = forceOnAll(consumed::add, + iterate(2, i -> i + -1).limit(3).mapToObj(denominator -> 2 / denominator)) + .collect(toList()); + assertAll( + () -> assertThat(exceptions, contains(isA(ArithmeticException.class))), + () -> assertThat(consumed, contains(1, 2))); + } + + @Test + void worksWithParalellStreams() { + AtomicLong successes = new AtomicLong(); + long failures = forceOnAll(__ -> successes.incrementAndGet(), + iterate(0, i -> i + 1).limit(100_000).parallel() + .map(i -> i % 4) // 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, ... + .map(denominator -> 4 / denominator) // Fails with div by zero 1/4 of the times + .boxed()) + .count(); + + assertThat("3 times as much successes as failures: " + successes + " / " + failures, + failures * 3, is(successes.get())); + } + + } + }