From 72456324727f4c07afeb780e7076f00946128d04 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Thu, 9 Jan 2025 13:32:54 -0500 Subject: [PATCH] [ML] Unified schema API remove name field (#119799) * Removing name field * Fixing test --- .../inference/UnifiedCompletionRequest.java | 15 ++++----------- .../action/UnifiedCompletionRequestTests.java | 4 ---- .../external/http/sender/UnifiedChatInput.java | 10 +--------- .../UnifiedChatCompletionRequestEntity.java | 3 --- .../http/sender/UnifiedChatInputTests.java | 16 ++-------------- ...UnifiedChatCompletionRequestEntityTests.java | 1 - ...UnifiedChatCompletionRequestEntityTests.java | 1 - ...UnifiedChatCompletionRequestEntityTests.java | 17 +---------------- ...ticInferenceServiceCompletionModelTests.java | 2 +- .../services/openai/OpenAiServiceTests.java | 4 +--- .../OpenAiChatCompletionModelTests.java | 4 ++-- 11 files changed, 12 insertions(+), 65 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index e2f47f1a7a343..ae6a3bb10bd31 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -111,18 +111,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalFloat(topP); } - public record Message( - Content content, - String role, - @Nullable String name, - @Nullable String toolCallId, - @Nullable List toolCalls - ) implements Writeable { + public record Message(Content content, String role, @Nullable String toolCallId, @Nullable List toolCalls) + implements + Writeable { @SuppressWarnings("unchecked") static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( Message.class.getSimpleName(), - args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List) args[4]) + args -> new Message((Content) args[0], (String) args[1], (String) args[2], (List) args[3]) ); static { @@ -133,7 +129,6 @@ public record Message( ObjectParser.ValueType.VALUE_ARRAY ); PARSER.declareString(constructorArg(), new ParseField("role")); - PARSER.declareString(optionalConstructorArg(), new ParseField("name")); PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id")); PARSER.declareObjectArray(optionalConstructorArg(), ToolCall.PARSER::apply, new ParseField("tool_calls")); } @@ -155,7 +150,6 @@ public Message(StreamInput in) throws IOException { in.readOptionalNamedWriteable(Content.class), in.readString(), in.readOptionalString(), - in.readOptionalString(), in.readOptionalCollectionAsList(ToolCall::new) ); } @@ -164,7 +158,6 @@ public Message(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeOptionalNamedWriteable(content); out.writeString(role); - out.writeOptionalString(name); out.writeOptionalString(toolCallId); out.writeOptionalCollection(toolCalls); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java index 47a0814a584b7..120c2a6dbc5e7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -35,7 +35,6 @@ public void testParseAllFields() throws IOException { "type": "string" } ], - "name": "a name", "tool_call_id": "100", "tool_calls": [ { @@ -83,7 +82,6 @@ public void testParseAllFields() throws IOException { List.of(new UnifiedCompletionRequest.ContentObject("some text", "string")) ), "user", - "a name", "100", List.of( new UnifiedCompletionRequest.ToolCall( @@ -155,7 +153,6 @@ public void testParsing() throws IOException { new UnifiedCompletionRequest.ContentString("What is the weather like in Boston today?"), "user", null, - null, null ) ), @@ -200,7 +197,6 @@ public static UnifiedCompletionRequest.Message randomMessage() { randomContent(), randomAlphaOfLength(10), randomAlphaOfLengthOrNull(10), - randomAlphaOfLengthOrNull(10), randomToolCallListOrNull() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java index f89fa1ee37a6f..fceec7c431182 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -40,15 +40,7 @@ public UnifiedChatInput(List inputs, String roleValue, boolean stream) { private static List convertToMessages(List inputs, String roleValue) { return inputs.stream() - .map( - value -> new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString(value), - roleValue, - null, - null, - null - ) - ) + .map(value -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(value), roleValue, null, null)) .toList(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java index 3ea8e28479ef2..5e6d09cde2b9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java @@ -77,9 +77,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.field(ROLE_FIELD, message.role()); - if (message.name() != null) { - builder.field(NAME_FIELD, message.name()); - } if (message.toolCallId() != null) { builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java index 42e1b18168aec..1c0643739d410 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java @@ -24,20 +24,8 @@ public void testConvertsStringInputToMessages() { Matchers.is( UnifiedCompletionRequest.of( List.of( - new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("hello"), - "a role", - null, - null, - null - ), - new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("awesome"), - "a role", - null, - null, - null - ) + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "a role", null, null), + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("awesome"), "a role", null, null) ) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java index 75ff63e1314ac..15b4898650784 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java @@ -32,7 +32,6 @@ public void testModelUserFieldsSerialization() throws IOException { new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, null, - null, null ); var messageList = new ArrayList(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java index f43b185391697..b0c58f3e94af8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -32,7 +32,6 @@ public void testModelUserFieldsSerialization() throws IOException { new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, null, - null, null ); var messageList = new ArrayList(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java index 0f305866ae988..d9388cab0e1ec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java @@ -39,7 +39,6 @@ public void testBasicSerialization() throws IOException { new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, null, - null, null ); var messageList = new ArrayList(); @@ -78,7 +77,6 @@ public void testSerializationWithAllFields() throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, - "name", "tool_call_id", Collections.singletonList( new UnifiedCompletionRequest.ToolCall( @@ -127,7 +125,6 @@ public void testSerializationWithAllFields() throws IOException { { "content": "Hello, world!", "role": "user", - "name": "name", "tool_call_id": "tool_call_id", "tool_calls": [ { @@ -189,7 +186,6 @@ public void testSerializationWithNullOptionalFields() throws IOException { new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, null, - null, null ); var messageList = new ArrayList(); @@ -240,7 +236,6 @@ public void testSerializationWithEmptyLists() throws IOException { new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, null, - null, Collections.emptyList() // empty toolCalls list ); var messageList = new ArrayList(); @@ -290,7 +285,6 @@ public void testSerializationWithNestedObjects() throws IOException { Random random = Randomness.get(); String randomContent = "Hello, world! " + random.nextInt(1000); - String randomName = "name" + random.nextInt(1000); String randomToolCallId = "tool_call_id" + random.nextInt(1000); String randomArguments = "arguments" + random.nextInt(1000); String randomFunctionName = "function_name" + random.nextInt(1000); @@ -303,7 +297,6 @@ public void testSerializationWithNestedObjects() throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString(randomContent), ROLE, - randomName, randomToolCallId, Collections.singletonList( new UnifiedCompletionRequest.ToolCall( @@ -357,7 +350,6 @@ public void testSerializationWithNestedObjects() throws IOException { { "content": "%s", "role": "user", - "name": "%s", "tool_call_id": "%s", "tool_calls": [ { @@ -416,7 +408,6 @@ public void testSerializationWithNestedObjects() throws IOException { } """, randomContent, - randomName, randomToolCallId, randomArguments, randomFunctionName, @@ -449,11 +440,10 @@ public void testSerializationWithDifferentContentTypes() throws IOException { new UnifiedCompletionRequest.ContentString(randomContentString), ROLE, null, - null, null ); - UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null, null); + UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null); var messageList = new ArrayList(); messageList.add(messageWithString); messageList.add(messageWithObjects); @@ -502,7 +492,6 @@ public void testSerializationWithSpecialCharacters() throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString("Hello, world! \n \"Special\" characters: \t \\ /"), ROLE, - "name\nwith\nnewlines", "tool_call_id\twith\ttabs", Collections.singletonList( new UnifiedCompletionRequest.ToolCall( @@ -541,7 +530,6 @@ public void testSerializationWithSpecialCharacters() throws IOException { { "content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /", "role": "user", - "name": "name\\nwith\\nnewlines", "tool_call_id": "tool_call_id\\twith\\ttabs", "tool_calls": [ { @@ -571,7 +559,6 @@ public void testSerializationWithBooleanFields() throws IOException { new UnifiedCompletionRequest.ContentString("Hello, world!"), ROLE, null, - null, null ); var messageList = new ArrayList(); @@ -641,7 +628,6 @@ public void testSerializationWithoutContentField() throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( null, "assistant", - "name\nwith\nnewlines", "tool_call_id\twith\ttabs", Collections.singletonList( new UnifiedCompletionRequest.ToolCall( @@ -669,7 +655,6 @@ public void testSerializationWithoutContentField() throws IOException { "messages": [ { "role": "assistant", - "name": "name\\nwith\\nnewlines", "tool_call_id": "tool_call_id\\twith\\ttabs", "tool_calls": [ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java index cc1463232e7e5..07da96cb32273 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java @@ -33,7 +33,7 @@ public void testOverridingModelId() { ); var request = new UnifiedCompletionRequest( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null, null)), + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null)), "new_model_id", null, null, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 03eacf17a8250..678c4528a3f41 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -967,9 +967,7 @@ public void testUnifiedCompletionInfer() throws Exception { service.unifiedCompletionInfer( model, UnifiedCompletionRequest.of( - List.of( - new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null, null) - ) + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) ), InferenceAction.Request.DEFAULT_TIMEOUT, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java index e7ac4cf879e92..2a5415f45c6d9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java @@ -51,7 +51,7 @@ public void testOverrideWith_NullMap() { public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); var request = new UnifiedCompletionRequest( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), "different_model", null, null, @@ -70,7 +70,7 @@ public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); var request = new UnifiedCompletionRequest( - List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), null, // not overriding model null, null,