Skip to content

Commit

Permalink
Add new unit tests, move some tests out of HTTPClient tests
Browse files Browse the repository at this point in the history
Signed-off-by: owenhalpert <ohalpert@gmail.com>
  • Loading branch information
owenhalpert committed Mar 4, 2025
1 parent 3e5746e commit 75b1afa
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public boolean supportsRemoteIndexBuild() {
@Override
public RemoteIndexParameters createRemoteIndexingParameters(Map<String, Object> indexInfoParameters) {
if (METHOD_HNSW.equals(indexInfoParameters.get(NAME))) {
return FaissHNSWMethod.getRemoteIndexingParameters(indexInfoParameters);
return FaissHNSWMethod.createRemoteIndexingParameters(indexInfoParameters);
}
throw new IllegalArgumentException("Unsupported method for remote indexing");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput
* @param indexInfoParameters result of indexInfo.getParameters() to parse
* @return Map of parameters to be used as "index_parameters"
*/
public static RemoteIndexParameters getRemoteIndexingParameters(Map<String, Object> indexInfoParameters) {
public static RemoteIndexParameters createRemoteIndexingParameters(Map<String, Object> indexInfoParameters) {
RemoteFaissHNSWIndexParameters.RemoteFaissHNSWIndexParametersBuilder<?, ?> builder = RemoteFaissHNSWIndexParameters.builder();
assert (indexInfoParameters.get(SPACE_TYPE) instanceof String);
String spaceType = (String) indexInfoParameters.get(SPACE_TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private static class HttpClientHolder {
private static final CloseableHttpClient httpClient = createHttpClient();

private static CloseableHttpClient createHttpClient() {
return HttpClients.custom().setRetryStrategy(new RemoteIndexClientRetryStrategy()).build();
return HttpClients.custom().setRetryStrategy(new RemoteIndexHTTPClientRetryStrategy()).build();
}
}

Expand Down Expand Up @@ -112,7 +112,6 @@ public RemoteBuildResponse submitVectorBuild(RemoteBuildRequest remoteBuildReque
* Helper method to form the HttpPost request from the HTTPRemoteBuildRequest
* @param jsonRequest JSON converted request body to be submitted
* @return HttpPost request to be submitted
* @throws IOException if the request cannot be formed
*/
private HttpPost getHttpPost(String jsonRequest) {
HttpPost buildRequest = new HttpPost(URI.create(endpoint) + BUILD_ENDPOINT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
* TODO Future work: tune this retry strategy (MAX_RETRIES, BASE_DELAY_MS, exponential backoff/jitter) based on benchmarking
* @see org.apache.hc.client5.http.impl.DefaultHttpRequestRetryStrategy
*/
public class RemoteIndexClientRetryStrategy extends DefaultHttpRequestRetryStrategy {
public class RemoteIndexHTTPClientRetryStrategy extends DefaultHttpRequestRetryStrategy {
private static final int SC_BANDWIDTH_LIMIT_EXCEEDED = 509;
private static final int MAX_RETRIES = 1; // 2 total attempts
private static final long BASE_DELAY_MS = 1000;
Expand All @@ -45,7 +45,7 @@ public class RemoteIndexClientRetryStrategy extends DefaultHttpRequestRetryStrat
SC_BANDWIDTH_LIMIT_EXCEEDED
);

public RemoteIndexClientRetryStrategy() {
public RemoteIndexHTTPClientRetryStrategy() {
super(
MAX_RETRIES,
TimeValue.ofMilliseconds(BASE_DELAY_MS),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.remote;

import org.junit.Before;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.metadata.RepositoryMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexSettings;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.test.OpenSearchSingleNodeTestCase;

import java.io.IOException;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.index.VectorDataType.FLOAT;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.DOC_ID_FILE_EXTENSION;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.S3;
import static org.opensearch.knn.index.remote.KNNRemoteConstants.VECTOR_BLOB_FILE_EXTENSION;
import static org.opensearch.knn.index.remote.RemoteIndexHTTPClientTests.MOCK_BLOB_NAME;
import static org.opensearch.knn.index.remote.RemoteIndexHTTPClientTests.TEST_BUCKET;
import static org.opensearch.knn.index.remote.RemoteIndexHTTPClientTests.TEST_CLUSTER;

public class RemoteBuildRequestTests extends OpenSearchSingleNodeTestCase {
@Mock
protected static ClusterService clusterService;

protected AutoCloseable openMocks;

@Before
public void setup() {
openMocks = MockitoAnnotations.openMocks(this);
clusterService = mock(ClusterService.class);
Set<Setting<?>> defaultClusterSettings = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
KNNSettings.state().setClusterService(clusterService);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, defaultClusterSettings));
}

/**
* Test the construction of the build request by comparing it to an explicitly created JSON object.
*/
public void testBuildRequest() {
RepositoryMetadata metadata = RemoteIndexHTTPClientTests.createTestRepositoryMetadata();
KNNSettings knnSettingsMock = mock(KNNSettings.class);
IndexSettings mockIndexSettings = RemoteIndexHTTPClientTests.createTestIndexSettings();

try (MockedStatic<KNNSettings> knnSettingsStaticMock = Mockito.mockStatic(KNNSettings.class)) {
knnSettingsStaticMock.when(KNNSettings::state).thenReturn(knnSettingsMock);
KNNSettings.state().setClusterService(clusterService);

BuildIndexParams indexInfo = RemoteIndexHTTPClientTests.createTestBuildIndexParams();

RemoteBuildRequest request = new RemoteBuildRequest(mockIndexSettings, indexInfo, metadata, MOCK_BLOB_NAME);

assertEquals(S3, request.getRepositoryType());
assertEquals(TEST_BUCKET, request.getContainerName());
assertEquals(FAISS_NAME, request.getEngine());
assertEquals(FLOAT.getValue(), request.getVectorDataType());
assertEquals(MOCK_BLOB_NAME + VECTOR_BLOB_FILE_EXTENSION, request.getVectorPath());
assertEquals(MOCK_BLOB_NAME + DOC_ID_FILE_EXTENSION, request.getDocIdPath());
assertEquals(TEST_CLUSTER, request.getTenantId());
assertEquals(2, request.getDocCount());
assertEquals(2, request.getDimension());

String expectedJson = "{"
+ "\"repository_type\":\"s3\","
+ "\"container_name\":\"test-bucket\","
+ "\"vector_path\":\"blob.knnvec\","
+ "\"doc_id_path\":\"blob.knndid\","
+ "\"tenant_id\":\"test-cluster\","
+ "\"dimension\":2,"
+ "\"doc_count\":2,"
+ "\"data_type\":\"float\","
+ "\"engine\":\"faiss\","
+ "\"index_parameters\":{"
+ "\"space_type\":\"l2\","
+ "\"algorithm\":\"hnsw\","
+ "\"algorithm_parameters\":{"
+ "\"ef_construction\":94,"
+ "\"ef_search\":89,"
+ "\"m\":14"
+ "}"
+ "}"
+ "}";
XContentParser expectedParser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
expectedJson
);
Map<String, Object> expectedMap = expectedParser.map();

String jsonRequest;
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
request.toXContent(builder, ToXContentObject.EMPTY_PARAMS);
jsonRequest = builder.toString();
}

XContentParser generatedParser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
jsonRequest
);
Map<String, Object> generatedMap = generatedParser.map();

assertEquals(expectedMap, generatedMap);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.remote;

import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.knn.KNNTestCase;

import java.io.IOException;

import static org.opensearch.knn.index.remote.RemoteIndexHTTPClientTests.MOCK_JOB_ID;
import static org.opensearch.knn.index.remote.RemoteIndexHTTPClientTests.MOCK_JOB_ID_RESPONSE;

public class RemoteBuildResponseTests extends KNNTestCase {
public void testRemoteBuildResponseParsing() throws IOException {
try (
XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
MOCK_JOB_ID_RESPONSE
)
) {
RemoteBuildResponse response = RemoteBuildResponse.fromXContent(parser);
assertNotNull(response);
assertEquals(MOCK_JOB_ID, response.getJobId());
}
}

public void testRemoteBuildResponseParsingError() throws IOException {
String jsonResponse = "{\"error\":\"test-error\"}";
try (
XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
jsonResponse
)
) {
assertThrows(IOException.class, () -> RemoteBuildResponse.fromXContent(parser));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.remote;

import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.common.KNNConstants;

import java.io.IOException;
import java.util.Map;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.opensearch.core.xcontent.DeprecationHandler.THROW_UNSUPPORTED_OPERATION;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.index.SpaceType.L2;

@SuppressWarnings("unchecked")
public class RemoteFaissHNSWIndexParametersTests extends KNNTestCase {
public void testToXContent() throws IOException {
RemoteFaissHNSWIndexParameters params = spy(
RemoteFaissHNSWIndexParameters.builder()
.spaceType(L2.getValue())
.algorithm(METHOD_HNSW)
.m(16)
.efConstruction(88)
.efSearch(99)
.build()
);

XContentBuilder builder = XContentFactory.jsonBuilder();
params.toXContent(builder, ToXContent.EMPTY_PARAMS);

try (
XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY,
THROW_UNSUPPORTED_OPERATION,
builder.toString()
)
) {
Map<String, Object> map = parser.map();

assertEquals(L2.getValue(), map.get(METHOD_PARAMETER_SPACE_TYPE));
assertEquals(METHOD_HNSW, map.get(KNNRemoteConstants.ALGORITHM));

Map<String, Object> algorithmParams = (Map<String, Object>) map.get(KNNRemoteConstants.ALGORITHM_PARAMETERS);
assertEquals(16, algorithmParams.get(KNNConstants.METHOD_PARAMETER_M));
assertEquals(88, algorithmParams.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION));
assertEquals(99, algorithmParams.get(KNNConstants.METHOD_PARAMETER_EF_SEARCH));
}

verify(params).addAlgorithmParameters(any(XContentBuilder.class));
}
}
Loading

0 comments on commit 75b1afa

Please sign in to comment.