Skip to content

Commit

Permalink
[ML] Unified schema API remove name field (elastic#119799)
Browse files Browse the repository at this point in the history
* Removing name field

* Fixing test
  • Loading branch information
jonathan-buttner authored Jan 9, 2025
1 parent 75d1050 commit 7245632
Show file tree
Hide file tree
Showing 11 changed files with 12 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(topP);
}

public record Message(
Content content,
String role,
@Nullable String name,
@Nullable String toolCallId,
@Nullable List<ToolCall> toolCalls
) implements Writeable {
public record Message(Content content, String role, @Nullable String toolCallId, @Nullable List<ToolCall> toolCalls)
implements
Writeable {

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
Message.class.getSimpleName(),
args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List<ToolCall>) args[4])
args -> new Message((Content) args[0], (String) args[1], (String) args[2], (List<ToolCall>) args[3])
);

static {
Expand All @@ -133,7 +129,6 @@ public record Message(
ObjectParser.ValueType.VALUE_ARRAY
);
PARSER.declareString(constructorArg(), new ParseField("role"));
PARSER.declareString(optionalConstructorArg(), new ParseField("name"));
PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id"));
PARSER.declareObjectArray(optionalConstructorArg(), ToolCall.PARSER::apply, new ParseField("tool_calls"));
}
Expand All @@ -155,7 +150,6 @@ public Message(StreamInput in) throws IOException {
in.readOptionalNamedWriteable(Content.class),
in.readString(),
in.readOptionalString(),
in.readOptionalString(),
in.readOptionalCollectionAsList(ToolCall::new)
);
}
Expand All @@ -164,7 +158,6 @@ public Message(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalNamedWriteable(content);
out.writeString(role);
out.writeOptionalString(name);
out.writeOptionalString(toolCallId);
out.writeOptionalCollection(toolCalls);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ public void testParseAllFields() throws IOException {
"type": "string"
}
],
"name": "a name",
"tool_call_id": "100",
"tool_calls": [
{
Expand Down Expand Up @@ -83,7 +82,6 @@ public void testParseAllFields() throws IOException {
List.of(new UnifiedCompletionRequest.ContentObject("some text", "string"))
),
"user",
"a name",
"100",
List.of(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -155,7 +153,6 @@ public void testParsing() throws IOException {
new UnifiedCompletionRequest.ContentString("What is the weather like in Boston today?"),
"user",
null,
null,
null
)
),
Expand Down Expand Up @@ -200,7 +197,6 @@ public static UnifiedCompletionRequest.Message randomMessage() {
randomContent(),
randomAlphaOfLength(10),
randomAlphaOfLengthOrNull(10),
randomAlphaOfLengthOrNull(10),
randomToolCallListOrNull()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,7 @@ public UnifiedChatInput(List<String> inputs, String roleValue, boolean stream) {

private static List<UnifiedCompletionRequest.Message> convertToMessages(List<String> inputs, String roleValue) {
return inputs.stream()
.map(
value -> new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString(value),
roleValue,
null,
null,
null
)
)
.map(value -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(value), roleValue, null, null))
.toList();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

builder.field(ROLE_FIELD, message.role());
if (message.name() != null) {
builder.field(NAME_FIELD, message.name());
}
if (message.toolCallId() != null) {
builder.field(TOOL_CALL_ID_FIELD, message.toolCallId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,8 @@ public void testConvertsStringInputToMessages() {
Matchers.is(
UnifiedCompletionRequest.of(
List.of(
new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("hello"),
"a role",
null,
null,
null
),
new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("awesome"),
"a role",
null,
null,
null
)
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "a role", null, null),
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("awesome"), "a role", null, null)
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ public void testModelUserFieldsSerialization() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ public void testModelUserFieldsSerialization() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ public void testBasicSerialization() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down Expand Up @@ -78,7 +77,6 @@ public void testSerializationWithAllFields() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
"name",
"tool_call_id",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -127,7 +125,6 @@ public void testSerializationWithAllFields() throws IOException {
{
"content": "Hello, world!",
"role": "user",
"name": "name",
"tool_call_id": "tool_call_id",
"tool_calls": [
{
Expand Down Expand Up @@ -189,7 +186,6 @@ public void testSerializationWithNullOptionalFields() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down Expand Up @@ -240,7 +236,6 @@ public void testSerializationWithEmptyLists() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
Collections.emptyList() // empty toolCalls list
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down Expand Up @@ -290,7 +285,6 @@ public void testSerializationWithNestedObjects() throws IOException {
Random random = Randomness.get();

String randomContent = "Hello, world! " + random.nextInt(1000);
String randomName = "name" + random.nextInt(1000);
String randomToolCallId = "tool_call_id" + random.nextInt(1000);
String randomArguments = "arguments" + random.nextInt(1000);
String randomFunctionName = "function_name" + random.nextInt(1000);
Expand All @@ -303,7 +297,6 @@ public void testSerializationWithNestedObjects() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString(randomContent),
ROLE,
randomName,
randomToolCallId,
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -357,7 +350,6 @@ public void testSerializationWithNestedObjects() throws IOException {
{
"content": "%s",
"role": "user",
"name": "%s",
"tool_call_id": "%s",
"tool_calls": [
{
Expand Down Expand Up @@ -416,7 +408,6 @@ public void testSerializationWithNestedObjects() throws IOException {
}
""",
randomContent,
randomName,
randomToolCallId,
randomArguments,
randomFunctionName,
Expand Down Expand Up @@ -449,11 +440,10 @@ public void testSerializationWithDifferentContentTypes() throws IOException {
new UnifiedCompletionRequest.ContentString(randomContentString),
ROLE,
null,
null,
null
);

UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null, null);
UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
messageList.add(messageWithString);
messageList.add(messageWithObjects);
Expand Down Expand Up @@ -502,7 +492,6 @@ public void testSerializationWithSpecialCharacters() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
new UnifiedCompletionRequest.ContentString("Hello, world! \n \"Special\" characters: \t \\ /"),
ROLE,
"name\nwith\nnewlines",
"tool_call_id\twith\ttabs",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -541,7 +530,6 @@ public void testSerializationWithSpecialCharacters() throws IOException {
{
"content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /",
"role": "user",
"name": "name\\nwith\\nnewlines",
"tool_call_id": "tool_call_id\\twith\\ttabs",
"tool_calls": [
{
Expand Down Expand Up @@ -571,7 +559,6 @@ public void testSerializationWithBooleanFields() throws IOException {
new UnifiedCompletionRequest.ContentString("Hello, world!"),
ROLE,
null,
null,
null
);
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
Expand Down Expand Up @@ -641,7 +628,6 @@ public void testSerializationWithoutContentField() throws IOException {
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
null,
"assistant",
"name\nwith\nnewlines",
"tool_call_id\twith\ttabs",
Collections.singletonList(
new UnifiedCompletionRequest.ToolCall(
Expand Down Expand Up @@ -669,7 +655,6 @@ public void testSerializationWithoutContentField() throws IOException {
"messages": [
{
"role": "assistant",
"name": "name\\nwith\\nnewlines",
"tool_call_id": "tool_call_id\\twith\\ttabs",
"tool_calls": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public void testOverridingModelId() {
);

var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null, null)),
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null)),
"new_model_id",
null,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -967,9 +967,7 @@ public void testUnifiedCompletionInfer() throws Exception {
service.unifiedCompletionInfer(
model,
UnifiedCompletionRequest.of(
List.of(
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null, null)
)
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
),
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void testOverrideWith_NullMap() {
public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() {
var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user");
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)),
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
"different_model",
null,
null,
Expand All @@ -70,7 +70,7 @@ public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() {
public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() {
var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user");
var request = new UnifiedCompletionRequest(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)),
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
null, // not overriding model
null,
null,
Expand Down

0 comments on commit 7245632

Please sign in to comment.