Skip to content

Commit

Permalink
Skip defensive copies and transforms in AsyncRequestBody
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenFlavin committed Apr 20, 2023
1 parent 46c2ae8 commit 1f41881
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import software.amazon.awssdk.core.internal.async.ByteArrayAsyncRequestBody;
import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody;
import software.amazon.awssdk.core.internal.async.InputStreamWithExecutorAsyncRequestBody;
import software.amazon.awssdk.core.internal.async.SimpleAsyncRequestBody;
import software.amazon.awssdk.core.internal.util.Mimetype;
import software.amazon.awssdk.utils.BinaryUtils;

/**
* Interface to allow non-blocking streaming of request content. This follows the reactive streams pattern where
Expand Down Expand Up @@ -124,11 +124,11 @@ static AsyncRequestBody fromFile(File file) {
* @param string The string to provide.
* @param cs The {@link Charset} to use.
* @return Implementation of {@link AsyncRequestBody} that uses the specified string.
* @see ByteArrayAsyncRequestBody
* @see SimpleAsyncRequestBody
*/
static AsyncRequestBody fromString(String string, Charset cs) {
return new ByteArrayAsyncRequestBody(string.getBytes(cs),
Mimetype.MIMETYPE_TEXT_PLAIN + "; charset=" + cs.name());
return SimpleAsyncRequestBody.of(Mimetype.MIMETYPE_TEXT_PLAIN + "; charset=" + cs.name(),
string.getBytes(cs));
}

/**
Expand All @@ -143,25 +143,33 @@ static AsyncRequestBody fromString(String string) {
}

/**
* Creates a {@link AsyncRequestBody} from a byte array. The contents of the byte array are copied so modifications to the
* original byte array are not reflected in the {@link AsyncRequestBody}.
* Creates a {@link AsyncRequestBody} from a byte array.
*
* @param bytes The bytes to send to the service.
* @return AsyncRequestBody instance.
*/
static AsyncRequestBody fromBytes(byte[] bytes) {
return new ByteArrayAsyncRequestBody(bytes, Mimetype.MIMETYPE_OCTET_STREAM);
return SimpleAsyncRequestBody.of(bytes);
}

/**
* Creates a {@link AsyncRequestBody} from a {@link ByteBuffer}. Buffer contents are copied so any modifications
* made to the original {@link ByteBuffer} are not reflected in the {@link AsyncRequestBody}.
* Creates a {@link AsyncRequestBody} from a {@link ByteBuffer}.
*
* @param byteBuffer ByteBuffer to send to the service.
* @return AsyncRequestBody instance.
*/
static AsyncRequestBody fromByteBuffer(ByteBuffer byteBuffer) {
return fromBytes(BinaryUtils.copyAllBytesFrom(byteBuffer));
return SimpleAsyncRequestBody.of(null, byteBuffer);
}

/**
* Creates a {@link AsyncRequestBody} from an array of {@link ByteBuffer}.
*
* @param byteBuffers ByteBuffer[] to send to the service.
* @return AsyncRequestBody instance.
*/
static AsyncRequestBody fromByteBuffers(ByteBuffer... byteBuffers) {
return SimpleAsyncRequestBody.of(null, byteBuffers);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
* @see AsyncRequestBody#fromString(String)
*/
@SdkInternalApi
@Deprecated
public final class ByteArrayAsyncRequestBody implements AsyncRequestBody {
private static final Logger log = Logger.loggerFor(ByteArrayAsyncRequestBody.class);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.core.internal.async;

import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.internal.util.Mimetype;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.Logger;

/**
* An implementation of {@link AsyncRequestBody} for providing data from the supplied {@link ByteBuffer} arrau. This is created
* using static
* methods on {@link AsyncRequestBody}
*
* @see AsyncRequestBody#fromBytes(byte[])
* @see AsyncRequestBody#fromByteBuffer(ByteBuffer)
* @see AsyncRequestBody#fromString(String)
*/
@SdkInternalApi
public final class SimpleAsyncRequestBody implements AsyncRequestBody {
private static final Logger log = Logger.loggerFor(SimpleAsyncRequestBody.class);

private final String mimetype;
private final Long length;
private final ByteBuffer[] buffers;

private SimpleAsyncRequestBody(String mimetype, Long length, ByteBuffer... buffers) {
this.mimetype = mimetype;
this.length = length;
this.buffers = buffers;
}

@Override
public Optional<Long> contentLength() {
return Optional.ofNullable(length);
}

@Override
public String contentType() {
return mimetype;
}

@Override
public void subscribe(Subscriber<? super ByteBuffer> s) {
// As per rule 1.9 we must throw NullPointerException if the subscriber parameter is null
if (s == null) {
throw new NullPointerException("Subscription MUST NOT be null.");
}

// As per 2.13, this method must return normally (i.e. not throw).
try {
s.onSubscribe(
new Subscription() {
private final AtomicInteger index = new AtomicInteger(0);
private final AtomicBoolean completed = new AtomicBoolean(false);

@Override
public void request(long n) {
if (completed.get()) {
return;
}

if (n > 0) {
int i = index.getAndIncrement();

if (i >= buffers.length) {
return;
}

long remaining = n;

do {
ByteBuffer buffer = buffers[i];
if (!buffer.hasArray()) {
buffer = ByteBuffer.wrap(BinaryUtils.copyBytesFrom(buffer));
}
s.onNext(buffer);
remaining--;
} while (remaining > 0 && (i = index.getAndIncrement()) < buffers.length);

if (i >= buffers.length - 1 && completed.compareAndSet(false, true)) {
s.onComplete();
}
} else {
s.onError(new IllegalArgumentException("§3.9: non-positive requests are not allowed!"));
}
}

@Override
public void cancel() {
completed.set(true);
}
}
);
} catch (Throwable ex) {
log.error(() -> s + " violated the Reactive Streams rule 2.13 by throwing an exception from onSubscribe.", ex);
}
}

public static SimpleAsyncRequestBody of(Long length, ByteBuffer... buffers) {
return new SimpleAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, length, buffers);
}

public static SimpleAsyncRequestBody of(String mimetype, Long length, ByteBuffer... buffers) {
return new SimpleAsyncRequestBody(mimetype, length, buffers);
}

public static SimpleAsyncRequestBody of(byte[] bytes) {
return new SimpleAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, (long) bytes.length,
ByteBuffer.wrap(bytes));
}

public static SimpleAsyncRequestBody of(String mimetype, byte[] bytes) {
return new SimpleAsyncRequestBody(mimetype, (long) bytes.length, ByteBuffer.wrap(bytes));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.core.internal.async;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.utils.BinaryUtils;

class SimpleAsyncRequestBodyTest {

private static class TestSubscriber implements Subscriber<ByteBuffer> {
private Subscription subscription;
private boolean onCompleteCalled = false;
private int callsToComplete = 0;
private final List<ByteBuffer> publishedResults = Collections.synchronizedList(new ArrayList<>());

public void request(long n) {
subscription.request(n);
}

@Override
public void onSubscribe(Subscription s) {
this.subscription = s;
}

@Override
public void onNext(ByteBuffer byteBuffer) {
publishedResults.add(byteBuffer);
}

@Override
public void onError(Throwable throwable) {
throw new IllegalStateException(throwable);
}

@Override
public void onComplete() {
onCompleteCalled = true;
callsToComplete++;
}
}

@Test
public void subscriberIsMarkedAsCompleted() {
AsyncRequestBody requestBody = SimpleAsyncRequestBody.of("Hello World!".getBytes(StandardCharsets.UTF_8));

TestSubscriber subscriber = new TestSubscriber();
requestBody.subscribe(subscriber);
subscriber.request(1);

assertTrue(subscriber.onCompleteCalled);
assertEquals(1, subscriber.publishedResults.size());
}

@Test
public void subscriberIsMarkedAsCompletedWhenARequestIsMadeForMoreBuffersThanAreAvailable() {
AsyncRequestBody requestBody = SimpleAsyncRequestBody.of("Hello World!".getBytes(StandardCharsets.UTF_8));

TestSubscriber subscriber = new TestSubscriber();
requestBody.subscribe(subscriber);
subscriber.request(2);

assertTrue(subscriber.onCompleteCalled);
assertEquals(1, subscriber.publishedResults.size());
}

@Test
public void subscriberIsThreadSafeAndMarkedAsCompletedExactlyOnce() throws InterruptedException {
int numBuffers = 100;
AsyncRequestBody requestBody = SimpleAsyncRequestBody.of(null, IntStream.range(0, numBuffers)
.mapToObj(i -> ByteBuffer.wrap(new byte[1]))
.toArray(ByteBuffer[]::new));

TestSubscriber subscriber = new TestSubscriber();
requestBody.subscribe(subscriber);

int parallelism = 8;
ExecutorService executorService = Executors.newFixedThreadPool(parallelism);
for (int i = 0; i < parallelism; i++) {
executorService.submit(() -> {
for (int j = 0; j < numBuffers; j++) {
subscriber.request(2);
}
});
}
executorService.shutdown();
executorService.awaitTermination(1, TimeUnit.MINUTES);

assertTrue(subscriber.onCompleteCalled);
assertEquals(1, subscriber.callsToComplete);
assertEquals(numBuffers, subscriber.publishedResults.size());
}

@Test
public void subscriberIsNotMarkedAsCompletedWhenThereAreRemainingBuffersToPublish() {
byte[] helloWorld = "Hello World!".getBytes(StandardCharsets.UTF_8);
byte[] goodbyeWorld = "Goodbye World!".getBytes(StandardCharsets.UTF_8);
AsyncRequestBody requestBody = SimpleAsyncRequestBody.of((long) (helloWorld.length + goodbyeWorld.length),
ByteBuffer.wrap(helloWorld),
ByteBuffer.wrap(goodbyeWorld));

TestSubscriber subscriber = new TestSubscriber();
requestBody.subscribe(subscriber);
subscriber.request(1);

assertFalse(subscriber.onCompleteCalled);
assertEquals(1, subscriber.publishedResults.size());
}

@Test
public void subscriberReceivesAllBuffers() {
byte[] helloWorld = "Hello World!".getBytes(StandardCharsets.UTF_8);
byte[] goodbyeWorld = "Goodbye World!".getBytes(StandardCharsets.UTF_8);

AsyncRequestBody requestBody = SimpleAsyncRequestBody.of((long) (helloWorld.length + goodbyeWorld.length),
ByteBuffer.wrap(helloWorld),
ByteBuffer.wrap(goodbyeWorld));

TestSubscriber subscriber = new TestSubscriber();
requestBody.subscribe(subscriber);
subscriber.request(2);

assertEquals(2, subscriber.publishedResults.size());
assertTrue(subscriber.onCompleteCalled);
assertArrayEquals(helloWorld, BinaryUtils.copyAllBytesFrom(subscriber.publishedResults.get(0)));
assertArrayEquals(goodbyeWorld, BinaryUtils.copyAllBytesFrom(subscriber.publishedResults.get(1)));
}

@Test
public void canceledSubscriberDoesNotReturnNewResults() {
AsyncRequestBody requestBody = SimpleAsyncRequestBody.of(null, ByteBuffer.wrap(new byte[0]));

TestSubscriber subscriber = new TestSubscriber();
requestBody.subscribe(subscriber);

subscriber.subscription.cancel();
subscriber.request(1);

assertTrue(subscriber.publishedResults.isEmpty());
}

}

0 comments on commit 1f41881

Please sign in to comment.