Skip to content

Commit

Permalink
Add StatusRequest, Waiter interface
Browse files Browse the repository at this point in the history
Signed-off-by: owenhalpert <ohalpert@gmail.com>
  • Loading branch information
owenhalpert committed Mar 6, 2025
1 parent 1a9577b commit acd6118
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.common.CheckedTriFunction;
import org.opensearch.common.StreamContext;
Expand Down Expand Up @@ -206,7 +207,7 @@ private CheckedTriFunction<Integer, Long, Long, InputStreamContainer, IOExceptio

@Override
public void readFromRepository(String fileName, IndexOutputWithBuffer indexOutputWithBuffer) throws IOException {
if (fileName == null || fileName.isEmpty()) {
if (StringUtils.isBlank(fileName)) {
throw new IllegalArgumentException("download path is null or empty");
}
if (!fileName.endsWith(KNNEngine.FAISS.getExtension())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.remote.RemoteBuildRequest;
import org.opensearch.knn.index.remote.RemoteBuildResponse;
import org.opensearch.knn.index.remote.RemoteBuildStatusRequest;
import org.opensearch.knn.index.remote.RemoteBuildStatusResponse;
import org.opensearch.knn.index.remote.RemoteIndexClient;
import org.opensearch.knn.index.remote.RemoteIndexClientFactory;
import org.opensearch.knn.index.remote.RemoteIndexWaiter;
import org.opensearch.knn.index.remote.RemoteIndexWaiterFactory;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.repositories.Repository;
import org.opensearch.repositories.RepositoryMissingException;
Expand Down Expand Up @@ -132,19 +135,21 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException {
log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());

RemoteIndexClient client = RemoteIndexClientFactory.getRemoteIndexClient();
RemoteBuildRequest request = new RemoteBuildRequest(
RemoteBuildRequest buildRequest = new RemoteBuildRequest(
indexSettings,
indexInfo,
repository.getMetadata(),
blobPath.buildAsString()
);
stopWatch = new StopWatch().start();
RemoteBuildResponse remoteBuildResponse = client.submitVectorBuild(request);
RemoteBuildResponse remoteBuildResponse = client.submitVectorBuild(buildRequest);
time_in_millis = stopWatch.stop().totalTime().millis();
log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());

RemoteBuildStatusRequest remoteBuildStatusRequest = new RemoteBuildStatusRequest(remoteBuildResponse);
RemoteIndexWaiter waiter = RemoteIndexWaiterFactory.getRemoteIndexWaiter(client);
stopWatch = new StopWatch().start();
RemoteBuildStatusResponse remoteBuildStatusResponse = client.awaitVectorBuild(remoteBuildResponse);
RemoteBuildStatusResponse remoteBuildStatusResponse = waiter.awaitVectorBuild(remoteBuildStatusRequest);
time_in_millis = stopWatch.stop().totalTime().millis();
log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT;
import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
Expand Down Expand Up @@ -203,7 +204,7 @@ private static int getMFromIndexDescription(String indexDescription) {
static boolean supportsRemoteIndexBuild(Map<String, String> attributes) throws IOException {
String parametersJson = attributes.get("parameters");
String encoderName = getEncoderName(parametersJson);
return "flat".equals(encoderName);
return ENCODER_FLAT.equals(encoderName);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.remote;

import lombok.Getter;

/**
 * Request object for the `/_status` API. It extracts and holds the parameters the status call
 * needs from the response returned when the build was submitted.
 */
public class RemoteBuildStatusRequest {
    /** Job ID copied from the build response; exposed through the Lombok-generated {@code getJobId()}. */
    @Getter
    private final String jobId;

    /**
     * Creates a status request for the job referenced by the given build response.
     *
     * @param remoteBuildResponse build response whose job ID this request will carry
     */
    public RemoteBuildStatusRequest(RemoteBuildResponse remoteBuildResponse) {
        final String submittedJobId = remoteBuildResponse.getJobId();
        this.jobId = submittedJobId;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@

import lombok.Builder;
import lombok.Value;
import org.apache.commons.lang.StringUtils;
import org.opensearch.core.ParseField;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;

import static org.opensearch.knn.index.remote.KNNRemoteConstants.COMPLETED_INDEX_BUILD;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.ERROR_MESSAGE;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.FILE_NAME;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.TASK_STATUS;
Expand Down Expand Up @@ -54,12 +52,6 @@ static RemoteBuildStatusResponse fromXContent(XContentParser parser) throws IOEx
}
}
}
if (StringUtils.isBlank(builder.taskStatus)) {
throw new IOException("Invalid response format, missing " + TASK_STATUS);
}
if (COMPLETED_INDEX_BUILD.equals(builder.taskStatus) && StringUtils.isBlank(builder.fileName)) {
throw new IOException("Invalid response format, missing " + FILE_NAME + " for completed status");
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,15 @@ public interface RemoteIndexClient {
/**
* Submit a build to the Remote Vector Build Service.
* @return RemoteBuildResponse from the server
*/
RemoteBuildResponse submitVectorBuild(RemoteBuildRequest remoteBuildRequest) throws IOException;

/**
* Await the completion of the index build and for the server to return the path to the completed index
*
* @param remoteBuildResponse the /_build request response from the server
* @return remoteStatusResponse from the server
* @throws InterruptedException if the thread is interrupted while waiting for the build to complete
* @throws IOException if there is an error communicating with the server
*/
RemoteBuildStatusResponse awaitVectorBuild(RemoteBuildResponse remoteBuildResponse) throws InterruptedException, IOException;
RemoteBuildResponse submitVectorBuild(RemoteBuildRequest remoteBuildRequest) throws IOException;

/**
* Get the status of the index build
*
* @param remoteBuildResponse the /_build request response from the server
* Get the status of an index build
* @param remoteBuildStatusRequest the status request object containing the job ID to check
* @return remoteStatusResponse from the server
* @throws IOException if there is an error communicating with the server
*/
RemoteBuildStatusResponse getBuildStatus(RemoteBuildResponse remoteBuildResponse) throws IOException;
RemoteBuildStatusResponse getBuildStatus(RemoteBuildStatusRequest remoteBuildStatusRequest) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ private static CloseableHttpClient createHttpClient() {
* Get the Singleton shared HTTP client
* @return The static HTTP Client
*/
protected static CloseableHttpClient getHttpClient() {
static CloseableHttpClient getHttpClient() {
return HttpClientHolder.httpClient;
}

Expand Down Expand Up @@ -123,24 +123,13 @@ private HttpPost getHttpPost(String jsonRequest) {
return buildRequest;
}

/**
* Await the completion of the index build using a {@link RemoteIndexPoller}.
* @param remoteBuildResponse containing job_id from the server response used to track the job
* @return RemoteBuildStatusResponse containing the path to the completed index
*/
@Override
public RemoteBuildStatusResponse awaitVectorBuild(RemoteBuildResponse remoteBuildResponse) throws InterruptedException, IOException {
RemoteIndexPoller remoteIndexPoller = new RemoteIndexPoller(this);
return remoteIndexPoller.pollRemoteEndpoint(remoteBuildResponse);
}

/**
* Helper method to directly get the status response for a given build
* @param remoteBuildResponse containing job ID to check
* @param remoteBuildStatusRequest containing job ID to check
* @return The entire response for the status request
*/
public RemoteBuildStatusResponse getBuildStatus(RemoteBuildResponse remoteBuildResponse) throws IOException {
String jobId = remoteBuildResponse.getJobId();
public RemoteBuildStatusResponse getBuildStatus(RemoteBuildStatusRequest remoteBuildStatusRequest) throws IOException {
String jobId = remoteBuildStatusRequest.getJobId();
HttpGet request = new HttpGet(endpoint + STATUS_ENDPOINT + "/" + jobId);
if (authHeader != null) {
request.setHeader(HttpHeaders.AUTHORIZATION, authHeader);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,26 @@

package org.opensearch.knn.index.remote;

import org.apache.commons.lang.StringUtils;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.knn.index.KNNSettings;

import java.io.IOException;
import java.time.Duration;

import static org.opensearch.knn.index.remote.KNNRemoteConstants.COMPLETED_INDEX_BUILD;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.FAILED_INDEX_BUILD;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.FILE_NAME;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.RUNNING_INDEX_BUILD;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.TASK_STATUS;

/**
* Implementation of a {@link RemoteIndexWaiter} that awaits the vector build by polling.
*/
class RemoteIndexPoller implements RemoteIndexWaiter {
// The poller waits KNN_REMOTE_BUILD_CLIENT_POLL_INTERVAL * INITIAL_DELAY_FACTOR before sending the first status request
private static final int INITIAL_DELAY_FACTOR = 3;

class RemoteIndexPoller {
private final RemoteIndexClient client;

RemoteIndexPoller(RemoteIndexClient client) {
Expand All @@ -24,38 +34,55 @@ class RemoteIndexPoller {
/**
* Polls the remote endpoint for the status of the build job until timeout.
*
* @param remoteBuildResponse The response from the initial build request
* @param remoteBuildStatusRequest The response from the initial build request
* @return RemoteBuildStatusResponse containing the path of the completed build job
* @throws InterruptedException if the thread is interrupted while polling
* @throws IOException if an I/O error occurs
*/
@SuppressWarnings("BusyWait")
RemoteBuildStatusResponse pollRemoteEndpoint(RemoteBuildResponse remoteBuildResponse) throws InterruptedException, IOException {
public RemoteBuildStatusResponse awaitVectorBuild(RemoteBuildStatusRequest remoteBuildStatusRequest) throws InterruptedException,
IOException {
long startTime = System.currentTimeMillis();
long timeout = ((TimeValue) KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_CLIENT_TIMEOUT)).getMillis();
long pollInterval = ((TimeValue) (KNNSettings.state().getSettingValue(KNNSettings.KNN_REMOTE_BUILD_CLIENT_POLL_INTERVAL)))
.getMillis();

// Initial delay to allow build service to process the job and store the ID before getting its status.
// TODO tune default based on benchmarking
Thread.sleep(pollInterval * 3);
Thread.sleep(pollInterval * INITIAL_DELAY_FACTOR);

while (System.currentTimeMillis() - startTime < timeout) {
RemoteBuildStatusResponse remoteBuildStatusResponse = client.getBuildStatus(remoteBuildResponse);
RemoteBuildStatusResponse remoteBuildStatusResponse = client.getBuildStatus(remoteBuildStatusRequest);
Duration d = Duration.ofMillis(System.currentTimeMillis() - startTime);
String taskStatus = remoteBuildStatusResponse.getTaskStatus();
if (StringUtils.isBlank(taskStatus)) {
throw new IOException(String.format("Invalid response format, missing %s", TASK_STATUS));
}
switch (taskStatus) {
case COMPLETED_INDEX_BUILD:
case COMPLETED_INDEX_BUILD -> {
if (StringUtils.isBlank(remoteBuildStatusResponse.getFileName())) {
throw new IOException(String.format("Invalid response format, missing %s for %s status", FILE_NAME, taskStatus));
}
return remoteBuildStatusResponse;
case FAILED_INDEX_BUILD:
}
case FAILED_INDEX_BUILD -> {
String errorMessage = remoteBuildStatusResponse.getErrorMessage();
if (errorMessage != null) {
throw new InterruptedException("Index build failed: " + errorMessage);
}
throw new InterruptedException("Index build failed without an error message.");
case RUNNING_INDEX_BUILD:
Thread.sleep(pollInterval);
throw new InterruptedException(
String.format("Remote index build failed after %d minutes. %s", d.toMinutesPart(), errorMessage)
);
}
case RUNNING_INDEX_BUILD -> Thread.sleep(pollInterval);
default -> throw new IOException(String.format("Server returned invalid task status %s", taskStatus));
}
}
throw new InterruptedException("Build timed out, falling back to CPU build.");
Duration waitedDuration = Duration.ofMillis(System.currentTimeMillis() - startTime);
Duration timeoutDuration = Duration.ofMillis(timeout);
throw new InterruptedException(
String.format(
"Remote index build timed out after %d minutes, timeout is set to %d minutes. Falling back to CPU build",
waitedDuration.toMinutesPart(),
timeoutDuration.toMinutesPart()
)
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.remote;

import java.io.IOException;

/**
* Strategy for waiting on a remote vector index build. An implementation receives a
* {@link RemoteBuildStatusRequest} and returns the {@link RemoteBuildStatusResponse}
* once the build has completed.
*/
public interface RemoteIndexWaiter {

/**
* Wait for the remote index to be built and return its response when completed
* @param remoteBuildStatusRequest the status request object identifying the build to wait for
* @return remoteStatusResponse from the server
* @throws InterruptedException if the waiting process gets interrupted or build fails
* @throws IOException if there is an error communicating with the server
*/
RemoteBuildStatusResponse awaitVectorBuild(RemoteBuildStatusRequest remoteBuildStatusRequest) throws InterruptedException, IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.remote;

/**
 * Factory that supplies the {@link RemoteIndexWaiter} used to await a remote index build.
 */
public class RemoteIndexWaiterFactory {
    /**
     * Creates a waiter bound to the given client. Polling is the default strategy, so the
     * returned waiter is a {@link RemoteIndexPoller}.
     *
     * @param client remote index client handed to the waiter
     * @return a new {@link RemoteIndexPoller} wrapping {@code client}
     */
    public static RemoteIndexWaiter getRemoteIndexWaiter(RemoteIndexClient client) {
        final RemoteIndexWaiter pollingWaiter = new RemoteIndexPoller(client);
        return pollingWaiter;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,28 +57,6 @@ public class RemoteBuildStatusResponseTests extends KNNTestCase {
+ "\":"
+ NULL
+ "}";
public static final String MISSING_TASK_STATUS_RESPONSE = "{"
+ "\""
+ FILE_NAME
+ "\":\""
+ MOCK_FILE_NAME
+ "\","
+ "\""
+ ERROR_MESSAGE
+ "\":"
+ NULL
+ "}";
public static final String MISSING_FILE_NAME_RESPONSE = "{"
+ "\""
+ TASK_STATUS
+ "\":\""
+ COMPLETED_INDEX_BUILD
+ "\","
+ "\""
+ ERROR_MESSAGE
+ "\":"
+ NULL
+ "}";

public void testSuccessfulBuildStatusResponse() throws IOException {
try (
Expand All @@ -96,32 +74,6 @@ public void testSuccessfulBuildStatusResponse() throws IOException {
}
}

public void testMissingTaskStatus() throws IOException {
try (
XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
MISSING_TASK_STATUS_RESPONSE
)
) {
IOException exception = assertThrows(IOException.class, () -> RemoteBuildStatusResponse.fromXContent(parser));
assertEquals("Invalid response format, missing " + TASK_STATUS, exception.getMessage());
}
}

public void testMissingIndexPathForCompletedStatus() throws IOException {
try (
XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
MISSING_FILE_NAME_RESPONSE
)
) {
IOException exception = assertThrows(IOException.class, () -> RemoteBuildStatusResponse.fromXContent(parser));
assertEquals("Invalid response format, missing " + FILE_NAME + " for completed status", exception.getMessage());
}
}

public void testUnknownField() throws IOException {
try (
XContentParser parser = JsonXContent.jsonXContent.createParser(
Expand Down
Loading

0 comments on commit acd6118

Please sign in to comment.