[ML] Adding new chat_completion task type for unified API (elastic#119982)

* Creating new chat completion task type

* Adding some comments

* Refactoring names and removing todo

* Exposing chat completion for openai and eis for now

* Fixing tests
jonathan-buttner authored Jan 15, 2025
1 parent d7474e6 commit f9a3721
Showing 11 changed files with 283 additions and 29 deletions.
3 changes: 0 additions & 3 deletions muted-tests.yml
@@ -224,9 +224,6 @@ tests:
- class: org.elasticsearch.search.profile.dfs.DfsProfilerIT
method: testProfileDfs
issue: https://github.com/elastic/elasticsearch/issues/119711
- class: org.elasticsearch.xpack.inference.InferenceCrudIT
method: testGetServicesWithCompletionTaskType
issue: https://github.com/elastic/elasticsearch/issues/119959
- class: org.elasticsearch.multi_cluster.MultiClusterYamlTestSuiteIT
issue: https://github.com/elastic/elasticsearch/issues/119983
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
TaskType.java
@@ -29,7 +29,8 @@ public enum TaskType implements Writeable {
public boolean isAnyOrSame(TaskType other) {
return true;
}
};
},
CHAT_COMPLETION;

public static final String NAME = "task_type";
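A minimal, runnable sketch (not part of this commit) of how the enum above behaves with the new value. The ANY override and the value names mirror this diff and the task types referenced elsewhere in the commit; the default isAnyOrSame body is an assumption about the rest of the class.

public class TaskTypeSketch {
    enum TaskType {
        TEXT_EMBEDDING,
        SPARSE_EMBEDDING,
        RERANK,
        COMPLETION,
        ANY {
            @Override
            public boolean isAnyOrSame(TaskType other) {
                return true; // ANY matches every task type, including CHAT_COMPLETION
            }
        },
        CHAT_COMPLETION;

        // Assumed default: identical types match, and ANY matches everything.
        public boolean isAnyOrSame(TaskType other) {
            return this == other || other == ANY;
        }
    }

    public static void main(String[] args) {
        System.out.println(TaskType.ANY.isAnyOrSame(TaskType.CHAT_COMPLETION)); // true
        System.out.println(TaskType.COMPLETION.isAnyOrSame(TaskType.CHAT_COMPLETION)); // false: distinct task types
    }
}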

InferenceCrudIT.java
@@ -242,12 +242,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
|| ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
assertThat(services.size(), equalTo(10));
} else {
assertThat(services.size(), equalTo(9));
}
assertThat(services.size(), equalTo(9));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@@ -269,9 +264,30 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
            )
        );

        if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
            || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
            providerList.add(6, "elastic");
        }

        assertArrayEquals(providers, providerList.toArray());
    }

    @SuppressWarnings("unchecked")
    public void testGetServicesWithChatCompletionTaskType() throws IOException {
        List<Object> services = getServices(TaskType.CHAT_COMPLETION);
        if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
            || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
            assertThat(services.size(), equalTo(2));
        } else {
            assertThat(services.size(), equalTo(1));
        }

        String[] providers = new String[services.size()];
        for (int i = 0; i < services.size(); i++) {
            Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
            providers[i] = (String) serviceConfig.get("service");
        }

        var providerList = new ArrayList<>(List.of("openai"));

        if ((ElasticInferenceServiceFeature.DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()
            || ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled())) {
            providerList.addFirst("elastic");
        }

        assertArrayEquals(providers, providerList.toArray());
Paths.java
@@ -30,12 +29,15 @@ public final class Paths {
+ "}/{"
+ INFERENCE_ID
+ "}/_stream";
static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_unified";

public static final String UNIFIED_SUFFIX = "_unified";
static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/" + UNIFIED_SUFFIX;
static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{"
+ TASK_TYPE_OR_INFERENCE_ID
+ "}/{"
+ INFERENCE_ID
+ "}/_unified";
+ "}/"
+ UNIFIED_SUFFIX;
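To make the template concrete, here is a hedged, runnable sketch (not part of this commit) that expands the new unified route into a URL. The placeholder values for TASK_TYPE_OR_INFERENCE_ID and INFERENCE_ID are assumptions (the real constants are defined elsewhere in Paths), and the endpoint id my-chat is made up.

public class UnifiedPathSketch {
    // Assumed placeholder names; the real constants live elsewhere in Paths.
    static final String TASK_TYPE_OR_INFERENCE_ID = "task_type_or_id";
    static final String INFERENCE_ID = "inference_id";
    public static final String UNIFIED_SUFFIX = "_unified";

    static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{"
        + TASK_TYPE_OR_INFERENCE_ID
        + "}/{"
        + INFERENCE_ID
        + "}/"
        + UNIFIED_SUFFIX;

    public static void main(String[] args) {
        String url = UNIFIED_TASK_TYPE_INFERENCE_ID_PATH
            .replace("{" + TASK_TYPE_OR_INFERENCE_ID + "}", "chat_completion")
            .replace("{" + INFERENCE_ID + "}", "my-chat");
        System.out.println(url); // _inference/chat_completion/my-chat/_unified
    }
}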

private Paths() {

SenderService.java
@@ -73,7 +73,7 @@ public void infer(

private static InferenceInputs createInput(Model model, List<String> input, @Nullable String query, boolean stream) {
return switch (model.getTaskType()) {
case COMPLETION -> new ChatCompletionInput(input, stream);
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
case RERANK -> new QueryAndDocsInputs(query, input, stream);
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream);
default -> throw new ElasticsearchStatusException(
ServiceUtils.java
@@ -42,6 +42,7 @@
import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.ENABLED;
import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS;
import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS;
import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_SUFFIX;
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;

public final class ServiceUtils {
@@ -780,5 +781,24 @@ public static void throwUnsupportedUnifiedCompletionOperation(String serviceName
throw new UnsupportedOperationException(Strings.format("The %s service does not support unified completion", serviceName));
}

public static String unsupportedTaskTypeForInference(Model model, EnumSet<TaskType> supportedTaskTypes) {
return Strings.format(
"Inference entity [%s] does not support task type [%s] for inference, the task type must be one of %s.",
model.getInferenceEntityId(),
model.getTaskType(),
supportedTaskTypes
);
}

public static String useChatCompletionUrlMessage(Model model) {
return org.elasticsearch.common.Strings.format(
"The task type for the inference entity is %s, please use the _inference/%s/%s/%s URL.",
model.getTaskType(),
model.getTaskType(),
model.getInferenceEntityId(),
UNIFIED_SUFFIX
);
}

private ServiceUtils() {}
}
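Taken together, the two new helpers produce the 400 response body used by the service guards later in this commit. A hedged sketch of that composition (the wrapper method name is illustrative, not part of the diff):

// Mirrors the doInfer guards added below; illustrative wiring only.
static String unsupportedTaskTypeResponse(Model model, EnumSet<TaskType> supportedInferenceActionTaskTypes) {
    var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, supportedInferenceActionTaskTypes);
    if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
        // Steer callers to the unified route, e.g. _inference/chat_completion/<id>/_unified
        responseString = responseString + " " + ServiceUtils.useChatCompletionUrlMessage(model);
    }
    return responseString;
}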
ElasticInferenceService.java
@@ -27,6 +27,7 @@
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -41,6 +42,7 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
Expand All @@ -61,6 +63,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;

public class ElasticInferenceService extends SenderService {

@@ -69,8 +72,16 @@ public class ElasticInferenceService extends SenderService {

private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION);
// The task types exposed via the _inference/_services API
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
TaskType.SPARSE_EMBEDDING,
TaskType.CHAT_COMPLETION
);
private static final String SERVICE_NAME = "Elastic";
/**
* The task types that the {@link InferenceAction.Request} can accept.
*/
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);

public ElasticInferenceService(
HttpRequestSender.Factory factory,
@@ -83,7 +94,7 @@ public ElasticInferenceService(

@Override
public Set<TaskType> supportedStreamingTasks() {
return COMPLETION_ONLY;
return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY);
}

@Override
@@ -129,6 +140,15 @@ protected void doInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
        if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
            var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES);

            if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
                responseString = responseString + " " + useChatCompletionUrlMessage(model);
            }
            listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
            // Return here so the listener is not notified a second time below.
            return;
        }

if (model instanceof ElasticInferenceServiceExecutableActionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
@@ -207,7 +227,7 @@ public InferenceServiceConfiguration getConfiguration() {

@Override
public EnumSet<TaskType> supportedTaskTypes() {
return supportedTaskTypes;
return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
}

private static ElasticInferenceServiceModel createModel(
@@ -383,7 +403,7 @@ public static InferenceServiceConfiguration get() {

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(SERVICE_NAME)
.setTaskTypes(supportedTaskTypes)
.setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
.setConfigurations(configurationMap)
.build();
}
OpenAiService.java
@@ -27,6 +27,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
@@ -63,14 +64,24 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.useChatCompletionUrlMessage;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.ORGANIZATION;

public class OpenAiService extends SenderService {
public static final String NAME = "openai";

private static final String SERVICE_NAME = "OpenAI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
// The task types exposed via the _inference/_services API
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
);
/**
* The task types that the {@link InferenceAction.Request} can accept.
*/
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);

public OpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
@@ -164,7 +175,7 @@ private static OpenAiModel createModel(
secretSettings,
context
);
case COMPLETION -> new OpenAiChatCompletionModel(
case COMPLETION, CHAT_COMPLETION -> new OpenAiChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
@@ -236,7 +247,7 @@ public InferenceServiceConfiguration getConfiguration() {

@Override
public EnumSet<TaskType> supportedTaskTypes() {
return supportedTaskTypes;
return SUPPORTED_TASK_TYPES_FOR_SERVICES_API;
}

@Override
@@ -248,6 +259,15 @@ public void doInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
        if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
            var responseString = ServiceUtils.unsupportedTaskTypeForInference(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES);

            if (model.getTaskType() == TaskType.CHAT_COMPLETION) {
                responseString = responseString + " " + useChatCompletionUrlMessage(model);
            }
            listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
            // Return here so the listener is not notified a second time below.
            return;
        }

if (model instanceof OpenAiModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
@@ -356,7 +376,7 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public Set<TaskType> supportedStreamingTasks() {
return COMPLETION_ONLY;
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION, TaskType.ANY);
}
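OpenAI now advertises both completion task types as streamable, and keeps TaskType.ANY in the set. A hedged sketch of a caller-side capability check; the helper below is a simplified stand-in, not the real InferenceService API:

// Illustrative only: checks a request's task type against the advertised set.
static boolean canStream(Set<TaskType> supportedStreamingTasks, TaskType requestTaskType) {
    // TaskType.ANY is kept in the set so requests that do not pin a concrete
    // task type still pass this membership check.
    return supportedStreamingTasks.contains(requestTaskType);
}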

/**
@@ -444,7 +464,7 @@ public static InferenceServiceConfiguration get() {

return new InferenceServiceConfiguration.Builder().setService(NAME)
.setName(SERVICE_NAME)
.setTaskTypes(supportedTaskTypes)
.setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API)
.setConfigurations(configurationMap)
.build();
}
Utils.java
@@ -148,19 +148,24 @@ private static <T> void blockingCall(
latch.await();
}

public static Model getInvalidModel(String inferenceEntityId, String serviceName) {
public static Model getInvalidModel(String inferenceEntityId, String serviceName, TaskType taskType) {
var mockConfigs = mock(ModelConfigurations.class);
when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId);
when(mockConfigs.getService()).thenReturn(serviceName);
when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
when(mockConfigs.getTaskType()).thenReturn(taskType);

var mockModel = mock(Model.class);
when(mockModel.getInferenceEntityId()).thenReturn(inferenceEntityId);
when(mockModel.getConfigurations()).thenReturn(mockConfigs);
when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
when(mockModel.getTaskType()).thenReturn(taskType);

return mockModel;
}

public static Model getInvalidModel(String inferenceEntityId, String serviceName) {
return getInvalidModel(inferenceEntityId, serviceName, TaskType.TEXT_EMBEDDING);
}
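A hedged usage example for the new overload (the ids and service name are made up): a test can now mint a mock model whose task type is CHAT_COMPLETION to exercise the doInfer guards added in this commit.

// Illustrative call from a service test; arguments are made up.
var model = getInvalidModel("model_id", "some_service", TaskType.CHAT_COMPLETION);
// Passing this model to doInfer should fail with a 400 whose message points at
// the _inference/chat_completion/model_id/_unified URL.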

public static SimilarityMeasure randomSimilarityMeasure() {
return randomFrom(SimilarityMeasure.values());
}
