Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Improve asynchronous test cases format #3601

Merged
merged 6 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,9 @@

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryPoolMXBean;
import java.lang.management.MemoryUsage;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.GZIPOutputStream;
Expand All @@ -35,12 +27,14 @@
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.io.entity.ByteArrayEntity;
import org.apache.hc.core5.http.message.BasicHeader;
import org.apache.http.HttpStatus;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.client.Client;
import org.opensearch.test.framework.AsyncActions;
import org.opensearch.test.framework.TestSecurityConfig;
import org.opensearch.test.framework.TestSecurityConfig.User;
import org.opensearch.test.framework.cluster.ClusterManager;
Expand Down Expand Up @@ -93,9 +87,8 @@ public void testUnauthenticatedFewBig() {
final String requestPath = "/*/_search";
final int parrallelism = 5;
final int totalNumberOfRequests = 100;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests);
peternied marked this conversation as resolved.
Show resolved Hide resolved
}

@Test
Expand All @@ -105,9 +98,8 @@ public void testUnauthenticatedManyMedium() {
final String requestPath = "/*/_search";
final int parrallelism = 20;
final int totalNumberOfRequests = 10_000;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests);
}

@Test
Expand All @@ -117,61 +109,26 @@ public void testUnauthenticatedTonsSmall() {
final String requestPath = "/*/_search";
final int parrallelism = 100;
final int totalNumberOfRequests = 1_000_000;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests);
}

private Long runResourceTest(
private void runResourceTest(
final RequestBodySize size,
final String requestPath,
final int parrallelism,
final int totalNumberOfRequests,
final boolean statsPrinter
final int totalNumberOfRequests
) {
final byte[] compressedRequestBody = createCompressedRequestBody(size);
try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) {

if (statsPrinter) {
printStats();
}
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));

final ForkJoinPool forkJoinPool = new ForkJoinPool(parrallelism);

final List<CompletableFuture<Void>> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests)
.boxed()
.map(i -> CompletableFuture.runAsync(() -> client.executeRequest(post), forkJoinPool))
.collect(Collectors.toList());
Supplier<Long> getCount = () -> waitingOn.stream().filter(cf -> cf.isDone() && !cf.isCompletedExceptionally()).count();

CompletableFuture<Void> statPrinter = statsPrinter ? CompletableFuture.runAsync(() -> {
while (true) {
printStats();
System.err.println(" & Succesful completions: " + getCount.get());
try {
Thread.sleep(500);
} catch (Exception e) {
break;
}
}
}, forkJoinPool) : CompletableFuture.completedFuture(null);

final CompletableFuture<Void> allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0]));

try {
allOfThem.get(30, TimeUnit.SECONDS);
statPrinter.cancel(true);
} catch (final Exception e) {
// Ignored
}

if (statsPrinter) {
printStats();
System.err.println(" & Succesful completions: " + getCount.get());
}
return getCount.get();
final var requests = AsyncActions.generate(() -> {
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post);
}, parrallelism, totalNumberOfRequests);

AsyncActions.getAll(requests, 30, TimeUnit.SECONDS)
.forEach((response) -> { response.assertStatusCode(HttpStatus.SC_UNAUTHORIZED); });
}
}

Expand Down Expand Up @@ -231,37 +188,4 @@ private byte[] createCompressedRequestBody(final RequestBodySize size) {
throw new RuntimeException(ioe);
}
}

DarshitChanpura marked this conversation as resolved.
Show resolved Hide resolved
private void printStats() {
peternied marked this conversation as resolved.
Show resolved Hide resolved
System.err.println("** Stats ");
printMemory();
printMemoryPools();
printGCPools();
}

private void printMemory() {
final Runtime runtime = Runtime.getRuntime();

final long totalMemory = runtime.totalMemory(); // Total allocated memory
final long freeMemory = runtime.freeMemory(); // Amount of free memory
final long usedMemory = totalMemory - freeMemory; // Amount of used memory

System.err.println(" Memory Total: " + totalMemory + " Free:" + freeMemory + " Used:" + usedMemory);
}

private void printMemoryPools() {
List<MemoryPoolMXBean> memoryPools = ManagementFactory.getMemoryPoolMXBeans();
for (MemoryPoolMXBean memoryPool : memoryPools) {
MemoryUsage usage = memoryPool.getUsage();
System.err.println(" " + memoryPool.getName() + " USED: " + usage.getUsed() + " MAX: " + usage.getMax());
}
}

private void printGCPools() {
List<GarbageCollectorMXBean> garbageCollectors = ManagementFactory.getGarbageCollectorMXBeans();
for (GarbageCollectorMXBean garbageCollector : garbageCollectors) {
System.err.println(" " + garbageCollector.getName() + " COLLECTION TIME: " + garbageCollector.getCollectionTime());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,17 @@
package org.opensearch.security.rest;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpStatus;
import org.apache.hc.core5.http.io.entity.ByteArrayEntity;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.anyOf;
import static org.hamcrest.MatcherAssert.assertThat;
import org.opensearch.test.framework.AsyncActions;
import org.opensearch.test.framework.TestSecurityConfig;
import org.opensearch.test.framework.cluster.ClusterManager;
import org.opensearch.test.framework.cluster.LocalCluster;
Expand All @@ -34,15 +30,14 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.GZIPOutputStream;
import org.opensearch.test.framework.cluster.TestRestClient.HttpResponse;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL;
import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS;
import static org.opensearch.test.framework.cluster.TestRestClientConfiguration.getBasicAuthHeader;
Expand All @@ -60,7 +55,7 @@ public class CompressionTests {
.build();

@Test
public void testAuthenticatedGzippedRequests() throws Exception {
public void testAuthenticatedGzippedRequests() {
final String requestPath = "/*/_search";
final int parallelism = 10;
final int totalNumberOfRequests = 100;
Expand All @@ -69,72 +64,54 @@ public void testAuthenticatedGzippedRequests() throws Exception {

final byte[] compressedRequestBody = createCompressedRequestBody(rawBody);
try (final TestRestClient client = cluster.getRestClient(ADMIN_USER, new BasicHeader("Content-Encoding", "gzip"))) {
final var requests = AsyncActions.generate(() -> {
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post);
}, parallelism, totalNumberOfRequests);

final ForkJoinPool forkJoinPool = new ForkJoinPool(parallelism);

final List<CompletableFuture<HttpResponse>> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests)
.boxed()
.map(i -> CompletableFuture.supplyAsync(() -> {
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post);
}, forkJoinPool))
.collect(Collectors.toList());

final CompletableFuture<Void> allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0]));

allOfThem.get(30, TimeUnit.SECONDS);

waitingOn.stream().forEach(future -> {
try {
final HttpResponse response = future.get();
response.assertStatusCode(HttpStatus.SC_OK);
} catch (final Exception ex) {
throw new RuntimeException(ex);
}
});
;
AsyncActions.getAll(requests, 30, TimeUnit.SECONDS).forEach((response) -> { response.assertStatusCode(HttpStatus.SC_OK); });
}
}

@Test
public void testMixOfAuthenticatedAndUnauthenticatedGzippedRequests() throws Exception {
final String requestPath = "/*/_search";
final int parallelism = 10;
final int totalNumberOfRequests = 100;
final int totalNumberOfRequests = 50;

final String rawBody = "{ \"query\": { \"match\": { \"foo\": \"bar\" }}}";

final byte[] compressedRequestBody = createCompressedRequestBody(rawBody);
try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) {

final ForkJoinPool forkJoinPool = new ForkJoinPool(parallelism);

final Header basicAuthHeader = getBasicAuthHeader(ADMIN_USER.getName(), ADMIN_USER.getPassword());

final List<CompletableFuture<HttpResponse>> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests)
.boxed()
.map(i -> CompletableFuture.supplyAsync(() -> {
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return i % 2 == 0 ? client.executeRequest(post) : client.executeRequest(post, basicAuthHeader);
}, forkJoinPool))
.collect(Collectors.toList());

final CompletableFuture<Void> allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0]));

allOfThem.get(30, TimeUnit.SECONDS);

waitingOn.stream().forEach(future -> {
try {
final HttpResponse response = future.get();
assertThat(response.getBody(), not(containsString("json_parse_exception")));
assertThat(response.getStatusCode(), anyOf(equalTo(HttpStatus.SC_UNAUTHORIZED), equalTo(HttpStatus.SC_OK)));
} catch (final Exception ex) {
throw new RuntimeException(ex);
}
final CountDownLatch countDownLatch = new CountDownLatch(1);

final var authorizedRequests = AsyncActions.generate(() -> {
countDownLatch.await();
System.err.println("Generation triggerd authorizedRequests");
peternied marked this conversation as resolved.
Show resolved Hide resolved
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post, getBasicAuthHeader(ADMIN_USER.getName(), ADMIN_USER.getPassword()));
}, parallelism, totalNumberOfRequests);

final var unauthorizedRequests = AsyncActions.generate(() -> {
countDownLatch.await();
System.err.println("Generation triggerd unauthorizedRequests");
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));
return client.executeRequest(post);
}, parallelism, totalNumberOfRequests);

// Make sure all requests start at the same time
countDownLatch.countDown();

AsyncActions.getAll(authorizedRequests, 30, TimeUnit.SECONDS).forEach((response) -> {
peternied marked this conversation as resolved.
Show resolved Hide resolved
assertThat(response.getStatusCode(), equalTo(HttpStatus.SC_OK));
});
AsyncActions.getAll(unauthorizedRequests, 30, TimeUnit.SECONDS).forEach((response) -> {
assertThat(response.getBody(), not(containsString("json_parse_exception")));
assertThat(response.getStatusCode(), equalTo(HttpStatus.SC_UNAUTHORIZED));
});
;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
*/

package org.opensearch.test.framework;

import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class AsyncActions {

/**
* Using the provided generator create a list of completable futures.
* @param parrallelism How many calls to the generator should be done at the same time.
* @param generationCount The total number of calls to the generator to conduct.
* @return The list of completable futures running on the fork join thread pool.
*/
public static <T> List<CompletableFuture<T>> generate(final Callable<T> generator, final int parrallelism, final int generationCount) {
final ForkJoinPool forkJoinPool = new ForkJoinPool(parrallelism);
return IntStream.rangeClosed(1, generationCount).boxed().map(i -> CompletableFuture.supplyAsync(() -> {
try {
return generator.call();
} catch (final Exception ex) {
throw new RuntimeException(ex);
}
}, forkJoinPool)).collect(Collectors.toList());
}

/**
* Waits for futures for a time period and then returns them a list
* @param futures Futures to wait for completion with a result
* @param n Amount of time to wait
* @param unit Time associated with those units
* @return Completed results from the futures
*/
public static <T> List<T> getAll(final List<CompletableFuture<T>> futures, final int n, final TimeUnit unit) {
final CompletableFuture<Void> futuresCompleted = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
try {
futuresCompleted.get(n, unit);
} catch (final Exception ex) {
final long completedFutures = futures.stream().filter(CompletableFuture::isDone).count();
throw new RuntimeException("Unable to wait for all futures to compete, " + completedFutures + " have finished.", ex);
}

return futures.stream().map(future -> {
try {
return future.get();
} catch (final Exception ex) {
throw new RuntimeException(ex);
}
}).collect(Collectors.toList());
}
}