diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java index 1c440bf0bde3..52c2af5fa4ad 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java @@ -30,8 +30,8 @@ import software.amazon.awssdk.core.internal.async.ByteArrayAsyncRequestBody; import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; import software.amazon.awssdk.core.internal.async.InputStreamWithExecutorAsyncRequestBody; +import software.amazon.awssdk.core.internal.async.SimpleAsyncRequestBody; import software.amazon.awssdk.core.internal.util.Mimetype; -import software.amazon.awssdk.utils.BinaryUtils; /** * Interface to allow non-blocking streaming of request content. This follows the reactive streams pattern where @@ -124,11 +124,11 @@ static AsyncRequestBody fromFile(File file) { * @param string The string to provide. * @param cs The {@link Charset} to use. * @return Implementation of {@link AsyncRequestBody} that uses the specified string. - * @see ByteArrayAsyncRequestBody + * @see SimpleAsyncRequestBody */ static AsyncRequestBody fromString(String string, Charset cs) { - return new ByteArrayAsyncRequestBody(string.getBytes(cs), - Mimetype.MIMETYPE_TEXT_PLAIN + "; charset=" + cs.name()); + return SimpleAsyncRequestBody.of(Mimetype.MIMETYPE_TEXT_PLAIN + "; charset=" + cs.name(), + string.getBytes(cs)); } /** @@ -143,25 +143,33 @@ static AsyncRequestBody fromString(String string) { } /** - * Creates a {@link AsyncRequestBody} from a byte array. The contents of the byte array are copied so modifications to the - * original byte array are not reflected in the {@link AsyncRequestBody}. + * Creates a {@link AsyncRequestBody} from a byte array. * * @param bytes The bytes to send to the service. * @return AsyncRequestBody instance. */ static AsyncRequestBody fromBytes(byte[] bytes) { - return new ByteArrayAsyncRequestBody(bytes, Mimetype.MIMETYPE_OCTET_STREAM); + return SimpleAsyncRequestBody.of(bytes); } /** - * Creates a {@link AsyncRequestBody} from a {@link ByteBuffer}. Buffer contents are copied so any modifications - * made to the original {@link ByteBuffer} are not reflected in the {@link AsyncRequestBody}. + * Creates a {@link AsyncRequestBody} from a {@link ByteBuffer}. * * @param byteBuffer ByteBuffer to send to the service. * @return AsyncRequestBody instance. */ static AsyncRequestBody fromByteBuffer(ByteBuffer byteBuffer) { - return fromBytes(BinaryUtils.copyAllBytesFrom(byteBuffer)); + return SimpleAsyncRequestBody.of(null, byteBuffer); + } + + /** + * Creates a {@link AsyncRequestBody} from an array of {@link ByteBuffer}. + * + * @param byteBuffers ByteBuffer[] to send to the service. + * @return AsyncRequestBody instance. + */ + static AsyncRequestBody fromByteBuffers(ByteBuffer... byteBuffers) { + return SimpleAsyncRequestBody.of(null, byteBuffers); } /** diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArrayAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArrayAsyncRequestBody.java index 29205479b798..42f0f8a62e42 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArrayAsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteArrayAsyncRequestBody.java @@ -32,6 +32,7 @@ * @see AsyncRequestBody#fromString(String) */ @SdkInternalApi +@Deprecated public final class ByteArrayAsyncRequestBody implements AsyncRequestBody { private static final Logger log = Logger.loggerFor(ByteArrayAsyncRequestBody.class); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SimpleAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SimpleAsyncRequestBody.java new file mode 100644 index 000000000000..e1f8f4d437c3 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SimpleAsyncRequestBody.java @@ -0,0 +1,136 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.internal.util.Mimetype; +import software.amazon.awssdk.utils.BinaryUtils; +import software.amazon.awssdk.utils.Logger; + +/** + * An implementation of {@link AsyncRequestBody} for providing data from the supplied {@link ByteBuffer} arrau. This is created + * using static + * methods on {@link AsyncRequestBody} + * + * @see AsyncRequestBody#fromBytes(byte[]) + * @see AsyncRequestBody#fromByteBuffer(ByteBuffer) + * @see AsyncRequestBody#fromString(String) + */ +@SdkInternalApi +public final class SimpleAsyncRequestBody implements AsyncRequestBody { + private static final Logger log = Logger.loggerFor(SimpleAsyncRequestBody.class); + + private final String mimetype; + private final Long length; + private final ByteBuffer[] buffers; + + private SimpleAsyncRequestBody(String mimetype, Long length, ByteBuffer... buffers) { + this.mimetype = mimetype; + this.length = length; + this.buffers = buffers; + } + + @Override + public Optional contentLength() { + return Optional.ofNullable(length); + } + + @Override + public String contentType() { + return mimetype; + } + + @Override + public void subscribe(Subscriber s) { + // As per rule 1.9 we must throw NullPointerException if the subscriber parameter is null + if (s == null) { + throw new NullPointerException("Subscription MUST NOT be null."); + } + + // As per 2.13, this method must return normally (i.e. not throw). + try { + s.onSubscribe( + new Subscription() { + private final AtomicInteger index = new AtomicInteger(0); + private final AtomicBoolean completed = new AtomicBoolean(false); + + @Override + public void request(long n) { + if (completed.get()) { + return; + } + + if (n > 0) { + int i = index.getAndIncrement(); + + if (i >= buffers.length) { + return; + } + + long remaining = n; + + do { + ByteBuffer buffer = buffers[i]; + if (!buffer.hasArray()) { + buffer = ByteBuffer.wrap(BinaryUtils.copyBytesFrom(buffer)); + } + s.onNext(buffer); + remaining--; + } while (remaining > 0 && (i = index.getAndIncrement()) < buffers.length); + + if (i >= buffers.length - 1 && completed.compareAndSet(false, true)) { + s.onComplete(); + } + } else { + s.onError(new IllegalArgumentException("ยง3.9: non-positive requests are not allowed!")); + } + } + + @Override + public void cancel() { + completed.set(true); + } + } + ); + } catch (Throwable ex) { + log.error(() -> s + " violated the Reactive Streams rule 2.13 by throwing an exception from onSubscribe.", ex); + } + } + + public static SimpleAsyncRequestBody of(Long length, ByteBuffer... buffers) { + return new SimpleAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, length, buffers); + } + + public static SimpleAsyncRequestBody of(String mimetype, Long length, ByteBuffer... buffers) { + return new SimpleAsyncRequestBody(mimetype, length, buffers); + } + + public static SimpleAsyncRequestBody of(byte[] bytes) { + return new SimpleAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, (long) bytes.length, + ByteBuffer.wrap(bytes)); + } + + public static SimpleAsyncRequestBody of(String mimetype, byte[] bytes) { + return new SimpleAsyncRequestBody(mimetype, (long) bytes.length, ByteBuffer.wrap(bytes)); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SimpleAsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SimpleAsyncRequestBodyTest.java new file mode 100644 index 000000000000..daf2952d231d --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SimpleAsyncRequestBodyTest.java @@ -0,0 +1,171 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.utils.BinaryUtils; + +class SimpleAsyncRequestBodyTest { + + private static class TestSubscriber implements Subscriber { + private Subscription subscription; + private boolean onCompleteCalled = false; + private int callsToComplete = 0; + private final List publishedResults = Collections.synchronizedList(new ArrayList<>()); + + public void request(long n) { + subscription.request(n); + } + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + publishedResults.add(byteBuffer); + } + + @Override + public void onError(Throwable throwable) { + throw new IllegalStateException(throwable); + } + + @Override + public void onComplete() { + onCompleteCalled = true; + callsToComplete++; + } + } + + @Test + public void subscriberIsMarkedAsCompleted() { + AsyncRequestBody requestBody = SimpleAsyncRequestBody.of("Hello World!".getBytes(StandardCharsets.UTF_8)); + + TestSubscriber subscriber = new TestSubscriber(); + requestBody.subscribe(subscriber); + subscriber.request(1); + + assertTrue(subscriber.onCompleteCalled); + assertEquals(1, subscriber.publishedResults.size()); + } + + @Test + public void subscriberIsMarkedAsCompletedWhenARequestIsMadeForMoreBuffersThanAreAvailable() { + AsyncRequestBody requestBody = SimpleAsyncRequestBody.of("Hello World!".getBytes(StandardCharsets.UTF_8)); + + TestSubscriber subscriber = new TestSubscriber(); + requestBody.subscribe(subscriber); + subscriber.request(2); + + assertTrue(subscriber.onCompleteCalled); + assertEquals(1, subscriber.publishedResults.size()); + } + + @Test + public void subscriberIsThreadSafeAndMarkedAsCompletedExactlyOnce() throws InterruptedException { + int numBuffers = 100; + AsyncRequestBody requestBody = SimpleAsyncRequestBody.of(null, IntStream.range(0, numBuffers) + .mapToObj(i -> ByteBuffer.wrap(new byte[1])) + .toArray(ByteBuffer[]::new)); + + TestSubscriber subscriber = new TestSubscriber(); + requestBody.subscribe(subscriber); + + int parallelism = 8; + ExecutorService executorService = Executors.newFixedThreadPool(parallelism); + for (int i = 0; i < parallelism; i++) { + executorService.submit(() -> { + for (int j = 0; j < numBuffers; j++) { + subscriber.request(2); + } + }); + } + executorService.shutdown(); + executorService.awaitTermination(1, TimeUnit.MINUTES); + + assertTrue(subscriber.onCompleteCalled); + assertEquals(1, subscriber.callsToComplete); + assertEquals(numBuffers, subscriber.publishedResults.size()); + } + + @Test + public void subscriberIsNotMarkedAsCompletedWhenThereAreRemainingBuffersToPublish() { + byte[] helloWorld = "Hello World!".getBytes(StandardCharsets.UTF_8); + byte[] goodbyeWorld = "Goodbye World!".getBytes(StandardCharsets.UTF_8); + AsyncRequestBody requestBody = SimpleAsyncRequestBody.of((long) (helloWorld.length + goodbyeWorld.length), + ByteBuffer.wrap(helloWorld), + ByteBuffer.wrap(goodbyeWorld)); + + TestSubscriber subscriber = new TestSubscriber(); + requestBody.subscribe(subscriber); + subscriber.request(1); + + assertFalse(subscriber.onCompleteCalled); + assertEquals(1, subscriber.publishedResults.size()); + } + + @Test + public void subscriberReceivesAllBuffers() { + byte[] helloWorld = "Hello World!".getBytes(StandardCharsets.UTF_8); + byte[] goodbyeWorld = "Goodbye World!".getBytes(StandardCharsets.UTF_8); + + AsyncRequestBody requestBody = SimpleAsyncRequestBody.of((long) (helloWorld.length + goodbyeWorld.length), + ByteBuffer.wrap(helloWorld), + ByteBuffer.wrap(goodbyeWorld)); + + TestSubscriber subscriber = new TestSubscriber(); + requestBody.subscribe(subscriber); + subscriber.request(2); + + assertEquals(2, subscriber.publishedResults.size()); + assertTrue(subscriber.onCompleteCalled); + assertArrayEquals(helloWorld, BinaryUtils.copyAllBytesFrom(subscriber.publishedResults.get(0))); + assertArrayEquals(goodbyeWorld, BinaryUtils.copyAllBytesFrom(subscriber.publishedResults.get(1))); + } + + @Test + public void canceledSubscriberDoesNotReturnNewResults() { + AsyncRequestBody requestBody = SimpleAsyncRequestBody.of(null, ByteBuffer.wrap(new byte[0])); + + TestSubscriber subscriber = new TestSubscriber(); + requestBody.subscribe(subscriber); + + subscriber.subscription.cancel(); + subscriber.request(1); + + assertTrue(subscriber.publishedResults.isEmpty()); + } + +} \ No newline at end of file